diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f564510..d58ff590 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,162 +1,5 @@ # Changelog -All notable changes to this project will be documented in this file. +All notable changes to this project will be documented as part of the release notes. -## [1.15.1] - 2023-12-07 -## Fixed -- Not receiving tokens when calling Get with options tokens as true. - -## [1.15.0] - 2023-10-30 -## Added -- options tokens support for Get method. - -## [1.14.0] - 2023-09-29 -## Added -- Support for different BYOT modes in Insert method. - -## [1.13.1] - 2023-09-14 -### Changed -- Add `request_index` in responses for insert method. - -## [1.13.0] - 2023-09-04 -### Added -- Added new Query method. - -## [1.12.0] - 2023-09-01 -### Added -- Support for Bulk request with Continue on Error in Detokenize Method -- Support for Continue on Error in Insert Method - -## [1.11.0] - 2023-08-25 -### Added -- Support for BYOT in Insert method. - -## [1.10.1] - 2023-07-28 -### Fixed -- Fixed delete method - -## [1.10.0] - 2023-07-21 -### Added -- Added delete method - -## [1.9.2] - 2023-06-22 -### Fixed -- Multiple record error in get method - -## [1.9.1] - 2023-06-07 -### Fixed -- Fixed bug in metrics - -## [1.9.0] - 2023-06-07 -### Added -- Added redaction type in detokenize - -## [1.8.1] - 2023-03-17 -### Removed -- removed grace period logic in bearer token generation - -## [1.8.0] - 2023-01-10 -### Added -- update and get methods. - -## [1.7.0] - 2022-12-07 -### Added -- `upsert` support for insert method. - -## [1.6.2] - 2022-06-28 - -### Added -- Copyright header to all files -- Security email in README - -## [1.6.1] - 2022-05-17 - -### Fixed - -- Insert with multiple records returning invalid output - -## [1.6.0] - 2022-04-12 - -### Added - -- support for application/x-www-form-urlencoded and multipart/form-data content-type's in connections. 
- -## [1.5.1] - 2022-03-29 - -### Added - -- Validation to token obtained from `tokenProvider` - -### Fixed - -- Request headers not getting overridden due to case sensitivity - -## [1.5.0] - 2022-03-22 - -### Changed - -- `getById` changed to `get_by_id` -- `invokeConnection`changed to `invoke_connection` -- `generateBearerToken` changed to `generate_bearer_token` -- `generateBearerTokenDromCreds` changed to `generate_bearer_token_from_creds` -- `isExpired` changed to `is_expired` -- `setLogLevel` changed to `set_log_level` - -### Removed - -- `isValid` function -- `GenerateToken` function - -## [1.4.0] - 2022-03-15 - -### Changed - -- deprecated `isValid` in favour of `isExpired` - -## [1.3.0] - 2022-02-24 - -### Added - -- Request ID in error logs and error responses for API Errors -- Caching to accessToken token -- `isValid` method for validating Service Account bearer token - -## [1.2.1] - 2022-01-18 - -### Fixed - -- `generateBearerTokenFromCreds` raising error "invalid credentials" on correct credentials - -## [1.2.0] - 2022-01-04 - -### Added - -- Logging functionality -- `setLogLevel` function for setting the package-level LogLevel -- `generateBearerTokenFromCreds` function which takes credentials as string - -### Changed - -- Renamed and deprecated `GenerateToken` in favor of `generateBearerToken` -- Make `vaultID` and `vaultURL` optional in `Client` constructor - -## [1.1.0] - 2021-11-10 - -### Added - -- `insert` vault API -- `detokenize` vault API -- `getById` vault API -- `invokeConnection` - -## [1.0.1] - 2021-10-26 - -### Changed - -- Package description - -## [1.0.0] - 2021-10-19 - -### Added - -- Service Account Token generation +See [GitHub](https://github.com/skyflowapi/skyflow-python/releases) or [PyPI](https://pypi.org/project/skyflow/#history) for more details on each released version. 
diff --git a/README.md b/README.md index b9d4ca86..cc2a78a4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Skyflow Python SDK +> **This is the current, recommended version of the Skyflow SDK.** V2.1.0 brings flexible auth, multi-vault support, native data types, and rich error diagnostics. +> +> Migrating from v1? See the **[Migration Guide](https://github.com/skyflowapi/skyflow-python/blob/main/docs/migrate_to_v2.md)** for step-by-step instructions. V1 is in maintenance mode and will reach End of Life on October 31, 2026. + The Skyflow Python SDK is designed to help with integrating Skyflow into a Python backend. ## Table of Contents @@ -235,8 +239,8 @@ from skyflow.utils.enums import RedactionType detokenize_request = DetokenizeRequest( data=[ - {'token': 'token1', 'redaction': RedactionType.PLAIN_TEXT}, - {'token': 'token2', 'redaction': RedactionType.PLAIN_TEXT} + {'token': 'token1', 'redaction_type': RedactionType.PLAIN_TEXT}, + {'token': 'token2', 'redaction_type': RedactionType.PLAIN_TEXT} ], continue_on_error=True ) @@ -406,7 +410,9 @@ Refer to [Query your data](https://docs.skyflow.com/query-data/) and [Execute Qu ### Upload File -Upload files to a Skyflow vault using the `upload_file` method. Create a file upload request with the `FileUploadRequest` class, which accepts parameters such as the table name, column name, and Skyflow ID. +Upload files to a Skyflow vault using the `upload_file` method. Create a file upload request with the `FileUploadRequest` class. 
+ +**Upload a file to an existing record:** ```python from skyflow.vault.data import FileUploadRequest @@ -414,13 +420,26 @@ from skyflow.vault.data import FileUploadRequest # Open the file in binary read mode with open('path/to/file.pdf', 'rb') as file_obj: upload_request = FileUploadRequest( - table='documents', # Table name - column_name='attachment', # Column name to store file - skyflow_id='', # Skyflow ID of the record - file_object=file_obj # Pass file object + table='', + column_name='', + skyflow_id='', + file_object=file_obj ) - - # Perform File Upload + + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload:', response) +``` + +**Upload a file and create a new record (omit `skyflow_id`):** + +```python +with open('path/to/file.pdf', 'rb') as file_obj: + upload_request = FileUploadRequest( + table='documents', + column_name='attachment', + file_object=file_obj + ) + response = skyflow_client.vault('').upload_file(upload_request) print('File upload:', response) ``` @@ -703,18 +722,65 @@ options = { Embed context values into a bearer token during generation so you can reference those values in your policies. This enables more flexible access controls, such as tracking end-user identity when making API calls using service accounts, and facilitates using signed data tokens during detokenization. -Generate bearer tokens containing context information using a service account with the context_id identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. Tokens generated from such service accounts include a context_identifier claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. +Generate bearer tokens containing context information using a service account with the `context_id` identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. 
Tokens generated from such service accounts include a `context_identifier` claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. + +The `ctx` parameter accepts either a **string** or a **dict**: + +**String context** — use when your policy references a single context value: + +```python +options = {'ctx': 'user_12345'} +token, _ = generate_bearer_token(filepath, options) +``` + +**Dict context** — use when your policy needs multiple context values for conditional data access. Each key in the dict maps to a Skyflow CEL policy variable under `request.context.*`: + +```python +options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } +} +token, _ = generate_bearer_token(filepath, options) +``` + +With the dict above, your Skyflow policies can reference `request.context.role`, `request.context.department`, and `request.context.user_id` to make conditional access decisions. + +Dict keys must contain only alphanumeric characters and underscores (`[a-zA-Z0-9_]`). Invalid keys will raise a `SkyflowError`. > [!TIP] -> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) -> See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. +> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) +> See Skyflow's [context-aware authorization](https://docs.skyflow.com) and [conditional data access](https://docs.skyflow.com) docs for policy variable syntax like `request.context.*`. #### Generate signed data tokens: `generate_signed_data_tokens(filepath, options)` Digitally sign data tokens with a service account's private key to add an extra layer of protection. 
Skyflow generates data tokens when sensitive data is inserted into the vault. Detokenize signed tokens only by providing the signed data token along with a bearer token generated from the service account's credentials. The service account must have the necessary permissions and context to successfully detokenize the signed data tokens. +The `ctx` parameter on signed data tokens also accepts either a **string** or a **dict**, using the same format as bearer tokens: + +```python +# String context +options = { + 'ctx': 'user_12345', + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} + +# Dict context +options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + }, + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} +``` + > [!TIP] -> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) +> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) > See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. 
## Logging diff --git a/ruff.toml b/ruff.toml index b6795704..aea6cce7 100644 --- a/ruff.toml +++ b/ruff.toml @@ -8,12 +8,13 @@ exclude = [ "venv", "build", "dist", - "tests" + "tests", + "samples" ] line-length = 120 [lint] -select = ["N"] +select = ["N", "PLR2004"] [lint.pep8-naming] diff --git a/samples/deprecated/detokenize_redaction_key.py b/samples/deprecated/detokenize_redaction_key.py new file mode 100644 index 00000000..63da46b8 --- /dev/null +++ b/samples/deprecated/detokenize_redaction_key.py @@ -0,0 +1,94 @@ +import json +from skyflow.error import SkyflowError +from skyflow import Env +from skyflow import Skyflow, LogLevel +from skyflow.utils.enums import RedactionType +from skyflow.vault.tokens import DetokenizeRequest + +""" + * [DEPRECATED] Skyflow Detokenize with 'redaction' Key Example + * + * The 'redaction' key in detokenize data is deprecated. + * Use 'redaction_type' instead. + * + * This example demonstrates how to: + * 1. Configure Skyflow client credentials + * 2. Set up vault configuration + * 3. Detokenize using the deprecated 'redaction' key + * 4. 
Handle response and errors + * + * Migration: + * Before: {'token': '', 'redaction': RedactionType.PLAIN_TEXT} + * After: {'token': '', 'redaction_type': RedactionType.PLAIN_TEXT} +""" + +def perform_detokenization_deprecated(): + try: + # Step 1: Configure Credentials + cred = { + 'clientID': '', + 'clientName': '', + 'tokenURI': '', + 'keyID': '', + 'privateKey': '', + } + + skyflow_credentials = { + 'credentials_string': json.dumps(cred) + } + + credentials = { + 'token': '' + } + + # Step 2: Configure Vault + primary_vault_config = { + 'vault_id': '', + 'cluster_id': '', + 'env': Env.PROD, + 'credentials': credentials + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(primary_vault_config) + .add_skyflow_credentials(skyflow_credentials) + .set_log_level(LogLevel.ERROR) + .build() + ) + + # Step 4: Prepare Detokenization Data using deprecated 'redaction' key + # DEPRECATED: 'redaction' key will be removed in a future release. + # Use 'redaction_type' instead (see migration above). 
+ detokenize_data = [ + { + 'token': '', + 'redaction': RedactionType.REDACTED + }, + { + 'token': '', + 'redaction': RedactionType.MASKED + } + ] + + detokenize_request = DetokenizeRequest( + data=detokenize_data, + continue_on_error=True + ) + + # Step 5: Perform Detokenization + response = skyflow_client.vault(primary_vault_config.get('vault_id')).detokenize(detokenize_request) + print('Detokenization successful: ', response) + + except SkyflowError as error: + print('Skyflow Specific Error: ', { + 'code': error.http_code, + 'message': error.message, + 'details': error.details + }) + except Exception as error: + print('Unexpected Error:', error) + +# Invoke the function +perform_detokenization_deprecated() diff --git a/samples/deprecated/file_upload_positional_args.py b/samples/deprecated/file_upload_positional_args.py new file mode 100644 index 00000000..e5edf50b --- /dev/null +++ b/samples/deprecated/file_upload_positional_args.py @@ -0,0 +1,85 @@ +import json +from skyflow.error import SkyflowError +from skyflow import Env +from skyflow import Skyflow, LogLevel +from skyflow.vault.data import FileUploadRequest + +""" + * [DEPRECATED] Skyflow FileUploadRequest Positional Arguments Example + * + * Passing positional arguments after 'table' in FileUploadRequest is deprecated. + * Use keyword arguments instead. + * + * This example demonstrates how to: + * 1. Configure Skyflow client credentials + * 2. Set up vault configuration + * 3. Upload a file using the deprecated positional argument order + * 4. 
Handle response and errors + * + * Migration: + * Before: FileUploadRequest(table, skyflow_id, column_name, file_object=file_obj) + * After: FileUploadRequest(table, column_name=column_name, skyflow_id=skyflow_id, file_object=file_obj) +""" + +def perform_file_upload_deprecated(): + try: + # Step 1: Configure Credentials + cred = { + 'clientID': '', + 'clientName': '', + 'tokenURI': '', + 'keyID': '', + 'privateKey': '', + } + + skyflow_credentials = { + 'credentials_string': json.dumps(cred) + } + + credentials = { + 'token': '' + } + + # Step 2: Configure Vault + primary_vault_config = { + 'vault_id': '', + 'cluster_id': '', + 'env': Env.PROD, + 'credentials': credentials + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(primary_vault_config) + .add_skyflow_credentials(skyflow_credentials) + .set_log_level(LogLevel.ERROR) + .build() + ) + + # Step 4: Perform File Upload using deprecated positional argument order + with open('', 'rb') as file_obj: + # DEPRECATED: positional args after 'table' are deprecated. + # Old order: (table, skyflow_id, column_name) + # Use keyword arguments instead (see migration above). 
+ upload_request = FileUploadRequest( + '', + '', + '', + file_object=file_obj + ) + + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload successful: ', response) + + except SkyflowError as error: + print('Skyflow Specific Error: ', { + 'code': error.http_code, + 'message': error.message, + 'details': error.details + }) + except Exception as error: + print('Unexpected Error:', error) + +# Invoke the function +perform_file_upload_deprecated() diff --git a/samples/deprecated/update_log_level.py b/samples/deprecated/update_log_level.py new file mode 100644 index 00000000..54f28a8f --- /dev/null +++ b/samples/deprecated/update_log_level.py @@ -0,0 +1,75 @@ +import json +from skyflow.error import SkyflowError +from skyflow import Env +from skyflow import Skyflow, LogLevel + +""" + * [DEPRECATED] Skyflow update_log_level Example + * + * update_log_level() is deprecated. Use set_log_level() instead. + * + * This example demonstrates how to: + * 1. Configure Skyflow client credentials + * 2. Set up vault configuration + * 3. Change the log level at runtime using the deprecated update_log_level() + * 4. 
Handle response and errors + * + * Migration: + * Before: skyflow_client.update_log_level(LogLevel.INFO) + * After: skyflow_client.set_log_level(LogLevel.INFO) +""" + +def perform_update_log_level(): + try: + # Step 1: Configure Credentials + cred = { + 'clientID': '', + 'clientName': '', + 'tokenURI': '', + 'keyID': '', + 'privateKey': '', + } + + skyflow_credentials = { + 'credentials_string': json.dumps(cred) + } + + credentials = { + 'token': '' + } + + # Step 2: Configure Vault + primary_vault_config = { + 'vault_id': '', + 'cluster_id': '', + 'env': Env.PROD, + 'credentials': credentials + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(primary_vault_config) + .add_skyflow_credentials(skyflow_credentials) + .set_log_level(LogLevel.ERROR) + .build() + ) + + # Step 4: Change log level at runtime using deprecated method + # DEPRECATED: update_log_level() will be removed in a future release. + # Use set_log_level() instead. + skyflow_client.update_log_level(LogLevel.INFO) + + print('Log level updated successfully (deprecated). 
Use set_log_level() instead.') + + except SkyflowError as error: + print('Skyflow Specific Error: ', { + 'code': error.http_code, + 'message': error.message, + 'details': error.details + }) + except Exception as error: + print('Unexpected Error:', error) + +# Invoke the function +perform_update_log_level() diff --git a/samples/detect_api/deidentify_file.py b/samples/detect_api/deidentify_file.py index 99b4b26e..88f012c9 100644 --- a/samples/detect_api/deidentify_file.py +++ b/samples/detect_api/deidentify_file.py @@ -1,7 +1,14 @@ from skyflow.error import SkyflowError from skyflow import Env, Skyflow, LogLevel from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions -from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep, FileInput +from skyflow.vault.detect import ( + DeidentifyFileRequest, + TokenFormat, + Transformations, + DateTransformation, + Bleep, + FileInput, +) """ * Skyflow Deidentify File Example @@ -11,6 +18,7 @@ * spreadsheets, presentations, structured text. 
""" + def perform_file_deidentification(): try: # Step 1: Configure Credentials @@ -23,7 +31,7 @@ def perform_file_deidentification(): 'vault_id': '', # Replace with your vault ID 'cluster_id': '', # Replace with your cluster ID 'env': Env.PROD, # Deployment environment - 'credentials': credentials + 'credentials': credentials, } # Step 3: Configure & Initialize Skyflow Client @@ -36,70 +44,66 @@ def perform_file_deidentification(): # Step 4: Create File Object file_path = '' # Replace with your file path - file = open(file_path, 'rb') - # Step 5: Configure Deidentify File Request with all options - deidentify_request = DeidentifyFileRequest( - file=FileInput(file), # File to de-identify (can also provide a file path) - entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect - allow_regex_list=[''], # Optional: Patterns to allow - restrict_regex_list=[''], # Optional: Patterns to restrict - - # Token format configuration - token_format=TokenFormat( - vault_token=[DetectEntities.SSN], # Use vault tokens for these entities - ), - - # Optional: Custom transformations - # transformations=Transformations( - # shift_dates=DateTransformation( - # max_days=30, - # min_days=10, - # entities=[DetectEntities.DOB] - # ) - # ), - - # Output configuration - output_directory='', # Where to save processed file - wait_time=15, # Max wait time in seconds (max 64) - - # Image-specific options - output_processed_image=True, # Include processed image in output - output_ocr_text=True, # Include OCR text in response - masking_method=MaskingMethod.BLACKBOX, # Masking method for images - - # PDF-specific options - pixel_density=15, # Pixel density for PDF processing - max_resolution=2000, # Max resolution for PDF - # Audio-specific options - output_processed_audio=True, # Include processed audio - output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type - - # Audio bleep configuration - - # bleep=Bleep( - # gain=5, # Loudness in 
dB - # frequency=1000, # Pitch in Hz - # start_padding=0.1, # Padding at start (seconds) - # stop_padding=0.2 # Padding at end (seconds) - # ) - ) - - # Step 6: Call deidentifyFile API - response = skyflow_client.detect().deidentify_file(deidentify_request) + # Step 5: Configure Deidentify File Request and call API + with open(file_path, 'rb') as file: + deidentify_request = DeidentifyFileRequest( + file=FileInput(file), # File to de-identify (can also provide a file path) + entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect + allow_regex_list=[''], # Optional: Patterns to allow + restrict_regex_list=[''], # Optional: Patterns to restrict + # Token format configuration + token_format=TokenFormat( + vault_token=[DetectEntities.SSN], # Use vault tokens for these entities + ), + # Optional: Custom transformations + # transformations=Transformations( + # shift_dates=DateTransformation( + # max_days=30, + # min_days=10, + # entities=[DetectEntities.DOB] + # ) + # ), + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) + # Image-specific options + output_processed_image=True, # Include processed image in output + output_ocr_text=True, # Include OCR text in response + masking_method=MaskingMethod.BLACKBOX, # Masking method for images + # PDF-specific options + pixel_density=15, # Pixel density for PDF processing + max_resolution=2000, # Max resolution for PDF + # Audio-specific options + output_processed_audio=True, # Include processed audio + output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type + # Audio bleep configuration + # bleep=Bleep( + # gain=5, # Loudness in dB + # frequency=1000, # Pitch in Hz + # start_padding=0.1, # Padding at start (seconds) + # stop_padding=0.2 # Padding at end (seconds) + # ) + ) + + # Step 6: Call deidentifyFile API + response = skyflow_client.detect().deidentify_file(deidentify_request) # 
Handle Successful Response - print("\nDeidentify File Response:", response) + print('\nDeidentify File Response:', response) except SkyflowError as error: # Handle Skyflow-specific errors - print('\nSkyflow Error:', { - 'http_code': error.http_code, - 'grpc_code': error.grpc_code, - 'http_status': error.http_status, - 'message': error.message, - 'details': error.details - }) + print( + '\nSkyflow Error:', + { + 'http_code': error.http_code, + 'grpc_code': error.grpc_code, + 'http_status': error.http_status, + 'message': error.message, + 'details': error.details, + }, + ) except Exception as error: # Handle unexpected errors print('Unexpected Error:', error) diff --git a/samples/detect_api/deidentify_file_async.py b/samples/detect_api/deidentify_file_async.py new file mode 100644 index 00000000..7ff2ac13 --- /dev/null +++ b/samples/detect_api/deidentify_file_async.py @@ -0,0 +1,124 @@ +from skyflow.error import SkyflowError +from skyflow import Env, Skyflow, LogLevel +from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions +from skyflow.vault.detect import ( + DeidentifyFileRequest, + TokenFormat, + Transformations, + DateTransformation, + Bleep, + FileInput, +) +from concurrent.futures import ThreadPoolExecutor + +""" + * Skyflow Deidentify File Example + * + * This sample demonstrates how to use all available options for deidentifying files + * using an asynchronous approach. + * Supported file types: images (jpg, png, etc.), pdf, audio (mp3, wav), documents, + * spreadsheets, presentations, structured text. 
+""" + + +def perform_file_deidentification_async(): + try: + # Step 1: Configure Credentials + credentials = { + 'path': '/path/to/credentials.json' # Path to credentials file + } + + # Step 2: Configure Vault + vault_config = { + 'vault_id': '', # Replace with your vault ID + 'cluster_id': '', # Replace with your cluster ID + 'env': Env.PROD, # Deployment environment + 'credentials': credentials, + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(vault_config) + .set_log_level(LogLevel.INFO) # Use LogLevel.ERROR in production + .build() + ) + + # Step 4: Create File Object + file_path = '' # Replace with your file path + + deidentify_request = DeidentifyFileRequest( + file=FileInput(file_path=file_path), # File to de-identify + # entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect + allow_regex_list=[''], # Optional: Patterns to allow + restrict_regex_list=[''], # Optional: Patterns to restrict + # Token format configuration + token_format=TokenFormat( + vault_token=[DetectEntities.SSN], # Use vault tokens for these entities + ), + # Optional: Custom transformations + # transformations=Transformations( + # shift_dates=DateTransformation( + # max_days=30, + # min_days=10, + # entities=[DetectEntities.DOB] + # ) + # ), + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) + # Image-specific options + output_processed_image=True, # Include processed image in output + output_ocr_text=True, # Include OCR text in response + masking_method=MaskingMethod.BLACKBOX, # Masking method for images + # PDF-specific options + pixel_density=15, # Pixel density for PDF processing + max_resolution=2000, # Max resolution for PDF + # Audio-specific options + output_processed_audio=True, # Include processed audio + output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type + # Audio 
bleep configuration + # bleep=Bleep( + # gain=5, # Loudness in dB + # frequency=1000, # Pitch in Hz + # start_padding=0.1, # Padding at start (seconds) + # stop_padding=0.2 # Padding at end (seconds) + # ) + ) + + # Create a thread pool executor + executor = ThreadPoolExecutor(max_workers=1) + + future = executor.submit(lambda: skyflow_client.detect().deidentify_file(deidentify_request)) + + def handle_response(future): + exception = future.exception() + if exception is not None: + if isinstance(exception, SkyflowError): + # Handle Skyflow-specific errors + print( + '\nSkyflow Error:', + { + 'http_code': exception.http_code, + 'grpc_code': exception.grpc_code, + 'http_status': exception.http_status, + 'message': exception.message, + 'details': exception.details, + }, + ) + else: + # Handle unexpected errors + print('Unexpected Error:', exception) + return + + # Handle Successful Response + result = future.result() + print('\nDeidentify File Response:', result) + + future.add_done_callback(handle_response) + + executor.shutdown(wait=True) + + except Exception as error: + # Handle unexpected errors + print('Unexpected Error:', error) diff --git a/samples/service_account/signed_token_generation_example.py b/samples/service_account/signed_token_generation_example.py index 32140ada..7ae175cd 100644 --- a/samples/service_account/signed_token_generation_example.py +++ b/samples/service_account/signed_token_generation_example.py @@ -1,12 +1,10 @@ import json from skyflow.service_account import ( - is_expired, generate_signed_data_tokens, generate_signed_data_tokens_from_creds, ) -file_path = 'CREDENTIALS_FILE_PATH' -bearer_token = '' +file_path = '' skyflow_credentials = { 'clientID': '', @@ -18,42 +16,64 @@ credentials_string = json.dumps(skyflow_credentials) -options = { - 'ctx': 'CONTEXT_ID', - 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], - 'time_to_live': 90, # in seconds -} +# Approach 1: Signed data tokens with string context +# Returns: [('', ''), ...] 
+def get_signed_tokens_with_string_context(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['', ''], + 'time_to_live': 90, # in seconds + } + try: + results = generate_signed_data_tokens(file_path, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results + except Exception as e: + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_file_path(): - # Generate signed bearer token from credentials file path. - global bearer_token +# Approach 2: Signed data tokens with JSON object context (dict) +# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "analyst" and request.context.department == "research" +def get_signed_tokens_with_object_context(): + options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + 'user_id': 'user_67890', + }, + 'data_tokens': ['', ''], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens(file_path, options) - return data_token, signed_data_token - + results = generate_signed_data_tokens(file_path, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_credentials_string(): - # Generate signed bearer token from credentials string. 
- global bearer_token - +# Approach 3: Signed data tokens from credentials string +def get_signed_tokens_from_credentials_string(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['', ''], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens_from_creds(credentials_string, options) - return data_token, signed_data_token - + results = generate_signed_data_tokens_from_creds(credentials_string, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results except Exception as e: - print(f'Error generating token from credentials string: {str(e)}') - + print(f'Error: {str(e)}') -print(get_signed_bearer_token_from_file_path()) -print(get_signed_bearer_token_from_credentials_string()) +print('String context:') +get_signed_tokens_with_string_context() +print('Object context:') +get_signed_tokens_with_object_context() +print('Creds string:') +get_signed_tokens_from_credentials_string() diff --git a/samples/service_account/token_generation_example.py b/samples/service_account/token_generation_example.py index 34db4c37..32fa022b 100644 --- a/samples/service_account/token_generation_example.py +++ b/samples/service_account/token_generation_example.py @@ -5,7 +5,7 @@ is_expired, ) -file_path = 'CREDENTIALS_FILE_PATH' +file_path = '' bearer_token = '' # To generate Bearer Token from credentials string. 
@@ -46,10 +46,9 @@ def get_bearer_token_from_credentials_string(): bearer_token = token return bearer_token except Exception as e: - print(f"Error generating token from credentials string: {str(e)}") - + print(f'Error generating token from credentials string: {str(e)}') print(get_bearer_token_from_file_path()) -print(get_bearer_token_from_credentials_string()) \ No newline at end of file +print(get_bearer_token_from_credentials_string()) diff --git a/samples/service_account/token_generation_with_context_example.py b/samples/service_account/token_generation_with_context_example.py index a43a072a..03aa9f06 100644 --- a/samples/service_account/token_generation_with_context_example.py +++ b/samples/service_account/token_generation_with_context_example.py @@ -18,11 +18,13 @@ } credentials_string = json.dumps(skyflow_credentials) -options = {'ctx': ''} -def get_bearer_token_with_context_from_file_path(): - # Generate bearer token with context from credentials file path. +# Approach 1: Bearer token with string context +# Use a simple string identifier when your policy references a single context value. +# In your Skyflow policy, reference this as: request.context +def get_bearer_token_with_string_context(): global bearer_token + options = {'ctx': 'user_12345'} try: if not is_expired(bearer_token): @@ -31,14 +33,40 @@ def get_bearer_token_with_context_from_file_path(): token, _ = generate_bearer_token(file_path, options) bearer_token = token return bearer_token + except Exception as e: + print(f'Error generating token: {str(e)}') + + +# Approach 2: Bearer token with JSON object context (dict) +# Use a dict when your policy needs multiple context values for conditional data access. 
+# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "admin" and request.context.department == "finance" +def get_bearer_token_with_object_context(): + global bearer_token + options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } + } + try: + if not is_expired(bearer_token): + return bearer_token + else: + token, _ = generate_bearer_token(file_path, options) + bearer_token = token + return bearer_token except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error generating token: {str(e)}') +# Approach 3: Bearer token with string context from credentials string def get_bearer_token_with_context_from_credentials_string(): - # Generate bearer token with context from credentials string. global bearer_token + options = {'ctx': 'user_12345'} + try: if not is_expired(bearer_token): return bearer_token @@ -47,9 +75,9 @@ def get_bearer_token_with_context_from_credentials_string(): bearer_token = token return bearer_token except Exception as e: - print(f"Error generating token from credentials string: {str(e)}") - + print(f"Error generating token: {str(e)}") -print(get_bearer_token_with_context_from_file_path()) -print(get_bearer_token_with_context_from_credentials_string()) \ No newline at end of file +print("String context:", get_bearer_token_with_string_context()) +print("Object context:", get_bearer_token_with_object_context()) +print("Creds string:", get_bearer_token_with_context_from_credentials_string()) diff --git a/samples/vault_api/credentials_options.py b/samples/vault_api/credentials_options.py index db792042..2155f99d 100644 --- a/samples/vault_api/credentials_options.py +++ b/samples/vault_api/credentials_options.py @@ -13,6 +13,7 @@ 4. 
Handle response and errors """ + def perform_secure_data_deletion(): try: # Step 1: Configure Bearer Token Credentials @@ -31,10 +32,10 @@ def perform_secure_data_deletion(): } secondary_vault_config = { - 'vault_id': 'YOUR_SECONDARY_VAULT_ID', # Secondary vault - 'cluster_id': 'YOUR_SECONDARY_CLUSTER_ID', # Cluster ID from your vault URL + 'vault_id': '', # Secondary vault + 'cluster_id': '', # Cluster ID from your vault URL 'env': Env.PROD, # Deployment environment - 'credentials': credentials + 'credentials': credentials, } # Step 3: Configure & Initialize Skyflow Client @@ -51,13 +52,10 @@ def perform_secure_data_deletion(): primary_table_name = '' # Replace with actual table name - primary_delete_request = DeleteRequest( - table=primary_table_name, - ids=primary_delete_ids - ) + primary_delete_request = DeleteRequest(table=primary_table_name, ids=primary_delete_ids) # Perform Delete Operation for Primary Vault - primary_delete_response = skyflow_client.vault('').delete(primary_delete_request) + primary_delete_response = skyflow_client.vault('').delete(primary_delete_request) # Handle Successful Response print('Primary Vault Deletion Successful:', primary_delete_response) @@ -67,10 +65,7 @@ def perform_secure_data_deletion(): secondary_table_name = '' # Replace with actual table name - secondary_delete_request = DeleteRequest( - table=secondary_table_name, - ids=secondary_delete_ids - ) + secondary_delete_request = DeleteRequest(table=secondary_table_name, ids=secondary_delete_ids) # Perform Delete Operation for Secondary Vault secondary_delete_response = skyflow_client.vault('').delete(secondary_delete_request) @@ -78,17 +73,12 @@ def perform_secure_data_deletion(): # Handle Successful Response print('Secondary Vault Deletion Successful:', secondary_delete_response) - except SkyflowError as error: # Comprehensive Error Handling - print('Skyflow Specific Error: ', { - 'code': error.http_code, - 'message': error.message, - 'details': error.details - }) + 
print('Skyflow Specific Error: ', {'code': error.http_code, 'message': error.message, 'details': error.details}) except Exception as error: print('Unexpected Error:', error) # Invoke the secure data deletion function -perform_secure_data_deletion() \ No newline at end of file +perform_secure_data_deletion() diff --git a/samples/vault_api/detokenize_records.py b/samples/vault_api/detokenize_records.py index e93d5a18..d0d10e0c 100644 --- a/samples/vault_api/detokenize_records.py +++ b/samples/vault_api/detokenize_records.py @@ -55,11 +55,11 @@ def perform_detokenization(): detokenize_data = [ { 'token': '', # Token to be detokenized - 'redaction': RedactionType.REDACTED + 'redaction_type': RedactionType.REDACTED }, { 'token': '', # Token to be detokenized - 'redaction': RedactionType.MASKED + 'redaction_type': RedactionType.MASKED } ] diff --git a/samples/vault_api/get_records.py b/samples/vault_api/get_records.py index b2fd445f..9e4d031a 100644 --- a/samples/vault_api/get_records.py +++ b/samples/vault_api/get_records.py @@ -4,6 +4,7 @@ from skyflow import Skyflow, LogLevel from skyflow.vault.data import GetRequest + def perform_secure_data_retrieval(): try: # Step 1: Configure Credentials @@ -28,7 +29,7 @@ def perform_secure_data_retrieval(): 'vault_id': '', # primary vault 'cluster_id': '', # Cluster ID from your vault URL 'env': Env.PROD, # Deployment environment (PROD by default) - 'credentials': credentials # Authentication method + 'credentials': credentials, # Authentication method } # Step 3: Configure & Initialize Skyflow Client @@ -42,10 +43,10 @@ def perform_secure_data_retrieval(): # Step 4: Prepare Retrieval Data - get_ids = ['', 'SKYFLOW_ID2'] + get_ids = ['', ''] get_request = GetRequest( - table='', # Replace with your actual table name + table='', # Replace with your actual table name ids=get_ids, ) @@ -57,15 +58,11 @@ def perform_secure_data_retrieval(): except SkyflowError as error: # Comprehensive Error Handling - print('Skyflow Specific Error: 
', { - 'code': error.http_code, - 'message': error.message, - 'details': error.details - }) + print('Skyflow Specific Error: ', {'code': error.http_code, 'message': error.message, 'details': error.details}) except Exception as error: print('Unexpected Error:', error) # Invoke the secure data retrieval function -perform_secure_data_retrieval() \ No newline at end of file +perform_secure_data_retrieval() diff --git a/samples/vault_api/upload_file.py b/samples/vault_api/upload_file.py index df3e8cd0..7c762b4b 100644 --- a/samples/vault_api/upload_file.py +++ b/samples/vault_api/upload_file.py @@ -6,12 +6,16 @@ """ * Skyflow File Upload Example - * + * * This example demonstrates how to: * 1. Configure Skyflow client credentials * 2. Set up vault configuration - * 3. Create a file upload request - * 4. Handle response and errors + * 3. Upload a file to an existing record (with skyflow_id) + * 4. Upload a file and create a new record (without skyflow_id) + * 5. Handle response and errors + * + * Note: All FileUploadRequest parameters must be + * passed as keyword arguments. 
""" def perform_file_upload(): @@ -35,8 +39,8 @@ def perform_file_upload(): # Step 2: Configure Vault primary_vault_config = { - 'vault_id': '', - 'cluster_id': '', + 'vault_id': '', + 'cluster_id': '', 'env': Env.PROD, 'credentials': credentials } @@ -50,20 +54,28 @@ def perform_file_upload(): .build() ) - # Step 4: Prepare File Upload Data + # Step 4a: Upload a file to an existing record with open('', 'rb') as file_obj: - file_upload_request = FileUploadRequest( - table='', # Table to upload file to - column_name='', # Column to upload file into - file_object=file_obj, # Pass file object - skyflow_id='' # Record ID to associate the file with + upload_request = FileUploadRequest( + table='', + column_name='', + skyflow_id='', + file_object=file_obj ) - # Step 5: Perform File Upload - response = skyflow_client.vault('').upload_file(file_upload_request) + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload to existing record:', response) - # Handle Successful Response - print('File upload successful: ', response) + # Step 4b: Upload a file and create a new record (omit skyflow_id) + with open('', 'rb') as file_obj: + upload_request = FileUploadRequest( + table='', + column_name='', + file_object=file_obj + ) + + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload with new record:', response) except SkyflowError as error: print('Skyflow Specific Error: ', { diff --git a/setup.py b/setup.py index 09f844d2..d356b664 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0' +current_version = '2.0.2.dev0+961116e' setup( name='skyflow', @@ -21,8 +21,8 @@ long_description=open('README.rst').read(), install_requires=[ 'python_dateutil >= 2.5.3', - 'setuptools >= 21.0.0', - 'urllib3 >= 1.25.3, < 2.1.0', + 'setuptools >= 75.3.3', + 'urllib3 >= 1.25.3, <= 2.6.3', 'pydantic >= 2', 'typing-extensions >= 4.7.1', 
'DateTime~=5.5', diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 9f0d9dbf..ebd5ef7d 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -2,7 +2,8 @@ from skyflow import LogLevel from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_info, Logger +from skyflow.utils.logger import log_info, log_warn, set_active_log_level, Logger +from skyflow.utils.constants import OptionField from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \ validate_update_connection_config, validate_credentials, validate_log_level from skyflow.vault.client.client import VaultClient @@ -30,7 +31,7 @@ def update_vault_config(self,config): self.__builder.update_vault_config(config) def get_vault_config(self, vault_id): - return self.__builder.get_vault_config(vault_id).get("vault_client").get_config() + return self.__builder.get_vault_config(vault_id).get(OptionField.VAULT_CLIENT).get_config() def add_connection_config(self, config): self.__builder._Builder__add_connection_config(config) @@ -45,7 +46,7 @@ def update_connection_config(self, config): return self def get_connection_config(self, connection_id): - return self.__builder.get_connection_config(connection_id).get("vault_client").get_config() + return self.__builder.get_connection_config(connection_id).get(OptionField.VAULT_CLIENT).get_config() def add_skyflow_credentials(self, credentials): self.__builder._Builder__add_skyflow_credentials(credentials) @@ -58,23 +59,25 @@ def set_log_level(self, log_level): self.__builder._Builder__set_log_level(log_level) return self + def update_log_level(self, log_level): + """.. deprecated:: Use set_log_level() instead. 
Will be removed in a future release.""" + log_warn(SkyflowMessages.Warning.UPDATE_LOG_LEVEL_DEPRECATED.value) + return self.set_log_level(log_level) + def get_log_level(self): return self.__builder._Builder__log_level - def update_log_level(self, log_level): - self.__builder._Builder__set_log_level(log_level) - def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("vault_controller") + return vault_config.get(OptionField.VAULT_CONTROLLER) def connection(self, connection_id = None) -> Connection: connection_config = self.__builder.get_connection_config(connection_id) - return connection_config.get("controller") + return connection_config.get(OptionField.CONTROLLER) def detect(self, vault_id = None) -> Detect: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("detect_controller") + return vault_config.get(OptionField.DETECT_CONTROLLER) class Builder: def __init__(self): @@ -87,13 +90,13 @@ def __init__(self): self.__logger = Logger(LogLevel.ERROR) def add_vault_config(self, config): - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) if not isinstance(vault_id, str) or not vault_id: raise SkyflowError( SkyflowMessages.Error.INVALID_VAULT_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if vault_id in [vault.get("vault_id") for vault in self.__vault_list]: + if vault_id in [vault.get(OptionField.VAULT_ID) for vault in self.__vault_list]: log_info(SkyflowMessages.Info.VAULT_CONFIG_EXISTS.value.format(vault_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(vault_id), @@ -112,9 +115,11 @@ def remove_vault_config(self, vault_id): def update_vault_config(self, config): validate_update_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) + if vault_id not in self.__vault_configs: + raise 
SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value) vault_config = self.__vault_configs[vault_id] - vault_config.get("vault_client").update_config(config) + vault_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_vault_config(self, vault_id): if vault_id is None: @@ -129,13 +134,13 @@ def get_vault_config(self, vault_id): def add_connection_config(self, config): - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) if not isinstance(connection_id, str) or not connection_id: raise SkyflowError( SkyflowMessages.Error.INVALID_CONNECTION_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if connection_id in [connection.get("connection_id") for connection in self.__connection_list]: + if connection_id in [connection.get(OptionField.CONNECTION_ID) for connection in self.__connection_list]: log_info(SkyflowMessages.Info.CONNECTION_CONFIG_EXISTS.value.format(connection_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id), @@ -153,9 +158,11 @@ def remove_connection_config(self, connection_id): def update_connection_config(self, config): validate_update_connection_config(self.__logger, config) - connection_id = config['connection_id'] + connection_id = config[OptionField.CONNECTION_ID] + if connection_id not in self.__connection_configs: + raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value) connection_config = self.__connection_configs[connection_id] - connection_config.get("vault_client").update_config(config) + connection_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_connection_config(self, connection_id): if connection_id is None: @@ -183,37 +190,38 @@ def get_logger(self): def __add_vault_config(self, config): 
validate_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_client = VaultClient(config) self.__vault_configs[vault_id] = { - "vault_client": vault_client, - "vault_controller": Vault(vault_client), - "detect_controller": Detect(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.VAULT_CONTROLLER: Vault(vault_client), + OptionField.DETECT_CONTROLLER: Detect(vault_client) } - log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) - log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) + log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) + log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) def __add_connection_config(self, config): validate_connection_config(self.__logger, config) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) vault_client = VaultClient(config) self.__connection_configs[connection_id] = { - "vault_client": vault_client, - "controller": Connection(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.CONTROLLER: Connection(vault_client) } - log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get("connection_id")), self.__logger) + log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.CONNECTION_ID)), self.__logger) def __update_vault_client_logger(self, log_level, logger): for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_logger(log_level,logger) + vault_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) for connection_id, connection_config in self.__connection_configs.items(): - 
connection_config.get("vault_client").set_logger(log_level,logger) + connection_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) def __set_log_level(self, log_level): validate_log_level(self.__logger, log_level) self.__log_level = log_level self.__logger.set_log_level(log_level) + set_active_log_level(log_level) self.__update_vault_client_logger(log_level, self.__logger) log_info(SkyflowMessages.Info.LOGGER_SETUP_DONE.value, self.__logger) log_info(SkyflowMessages.Info.CURRENT_LOG_LEVEL.value.format(self.__log_level), self.__logger) @@ -223,13 +231,14 @@ def __add_skyflow_credentials(self, credentials): self.__skyflow_credentials = credentials validate_credentials(self.__logger, credentials) for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_common_skyflow_credentials(credentials) + vault_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(credentials) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials) + connection_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(self.__skyflow_credentials) def build(self): validate_log_level(self.__logger, self.__log_level) self.__logger.set_log_level(self.__log_level) + set_active_log_level(self.__log_level) for config in self.__vault_list: self.__add_vault_config(config) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index 7b917fae..bf472177 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -1,5 +1,4 @@ from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_error class SkyflowError(Exception): def __init__(self, @@ -8,12 +7,11 @@ def __init__(self, request_id = None, grpc_code = None, http_status = None, - details = []): + details = None): self.message = message self.http_code = http_code self.grpc_code = grpc_code 
self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value - self.details = details + self.details = details if details else None self.request_id = request_id - log_error(message, http_code, request_id, grpc_code, http_status, details) - super().__init__() \ No newline at end of file + super().__init__(message) \ No newline at end of file diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 715716d8..deccf973 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -1,24 +1,74 @@ import json import datetime +import re import time import jwt +from urllib.parse import urlparse from skyflow.error import SkyflowError from skyflow.service_account.client.auth_client import AuthClient from skyflow.utils.logger import log_info, log_error_log from skyflow.utils import get_base_url, format_scope, SkyflowMessages +from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField +from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError +from skyflow.utils import is_valid_url +from skyflow.utils.constants import CTX_KEY_REGEX invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value +_CTX_KEY_PATTERN = re.compile(CTX_KEY_REGEX) + +_SNAKE_TO_CAMEL_CRED_MAP = { + 'private_key': CredentialField.PRIVATE_KEY, + 'client_id': CredentialField.CLIENT_ID, + 'key_id': CredentialField.KEY_ID, + 'token_uri': CredentialField.TOKEN_URI, + 'client_name': CredentialField.CLIENT_NAME, +} + + +def _normalize_credentials(credentials): + return {_SNAKE_TO_CAMEL_CRED_MAP.get(k, k): v for k, v in credentials.items()} + + +def _validate_and_resolve_ctx(ctx): + """Validate ctx value and return resolved value for JWT claims. + Returns None if ctx should be omitted, the value if valid, or raises SkyflowError if invalid. 
+ """ + if ctx is None: + return None + if isinstance(ctx, str): + if ctx.strip() == '': + return None + return ctx + if isinstance(ctx, dict): + if len(ctx) == 0: + return None + for key in ctx: + if not isinstance(key, str) or not _CTX_KEY_PATTERN.match(key): + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_MAP_KEY.value.format(key), + invalid_input_error_code + ) + return ctx + if isinstance(ctx, (bool, int, float)): + return ctx + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_TYPE.value, + invalid_input_error_code + ) + def is_expired(token, logger = None): + if token is None: + return True if len(token) == 0: log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True try: decoded = jwt.decode( - token, options={"verify_signature": False, "verify_aud": False}) - if time.time() >= decoded['exp']: + token, options={OptionField.VERIFY_SIGNATURE: False, OptionField.VERIFY_AUD: False}) + if time.time() >= decoded[JwtField.EXP]: log_info(SkyflowMessages.Info.BEARER_TOKEN_EXPIRED.value, logger) log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -30,20 +80,18 @@ def is_expired(token, logger = None): return True def generate_bearer_token(credentials_file_path, options = None, logger = None): + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) try: - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) - credentials_file =open(credentials_file_path, 'r') + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) + except SkyflowError: + raise except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - 
credentials = json.load(credentials_file) - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger) - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) - - finally: - credentials_file.close() result = get_service_account_token(credentials, options, logger) return result @@ -58,26 +106,37 @@ def generate_bearer_token_from_creds(credentials, options = None, logger = None) return result def get_service_account_token(credentials, options, logger): + credentials = _normalize_credentials(credentials) try: - private_key = credentials["privateKey"] - except: - log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) + private_key = credentials[CredentialField.PRIVATE_KEY] + except KeyError: + log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: - client_id = credentials["clientID"] - except: + client_id = credentials[CredentialField.CLIENT_ID] + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: - key_id = credentials["keyID"] - except: + key_id = credentials[CredentialField.KEY_ID] + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: - token_uri = credentials["tokenURI"] - except: + token_uri = credentials[CredentialField.TOKEN_URI] + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) + + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + 
log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + if options and CredentialField.TOKEN_URI_OPTION in options: + token_uri = options[CredentialField.TOKEN_URI_OPTION] + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) @@ -85,77 +144,92 @@ def get_service_account_token(credentials, options, logger): auth_api = auth_client.get_auth_api() formatted_scope = None - if options and "role_ids" in options: - formatted_scope = format_scope(options.get("role_ids")) + if options and OptionField.ROLE_IDS in options: + formatted_scope = format_scope(options.get(OptionField.ROLE_IDS)) - response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + try: + response = auth_api.authentication_service_get_auth_token(assertion = signed_token, + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + except UnauthorizedError: + log_error_log(SkyflowMessages.ErrorLogs.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, invalid_input_error_code) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.FAILED_TO_GET_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value, invalid_input_error_code) return response.access_token, response.token_type def 
get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): payload = { - "iss": client_id, - "key": key_id, - "aud": token_uri, - "sub": client_id, - "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60) + JwtField.ISS: client_id, + JwtField.KEY: key_id, + JwtField.AUD: token_uri, + JwtField.SUB: client_id, + JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and "ctx" in options: - payload["ctx"] = options.get("ctx") + if options and OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options.get(OptionField.CTX)) + if resolved_ctx is not None: + payload[JwtField.CTX] = resolved_ctx try: - return jwt.encode(payload=payload, key=private_key, algorithm="RS256") + return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code) def get_signed_tokens(credentials_obj, options): - try: - expiry_time = int(time.time()) + options.get("time_to_live", 60) - prefix = "signed_token_" - - if options and options.get("data_tokens"): - for token in options["data_tokens"]: - claims = { - "iss": "sdk", - "key": credentials_obj.get("keyID"), - "exp": expiry_time, - "sub": credentials_obj.get("clientID"), - "tok": token, - "iat": int(time.time()), - } - - if "ctx" in options: - claims["ctx"] = options["ctx"] - - private_key = credentials_obj.get("privateKey") - signed_jwt = jwt.encode(claims, private_key, algorithm="RS256") - response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) - log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) - return response_object - - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + options = options if options is not None else {} + credentials_obj = _normalize_credentials(credentials_obj) + expiry_time = int(time.time()) + 
options.get(OptionField.TIME_TO_LIVE, 60) + prefix = JWT.SIGNED_TOKEN_PREFIX + + token_uri = credentials_obj.get(CredentialField.TOKEN_URI) + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + resolved_ctx = None + if OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options[OptionField.CTX]) + + results = [] + if options and options.get(OptionField.DATA_TOKENS): + for token in options[OptionField.DATA_TOKENS]: + claims = { + JwtField.ISS: JWT.ISSUER_SDK, + JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID), + JwtField.EXP: expiry_time, + JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID), + JwtField.TOK: token, + JwtField.IAT: int(time.time()), + } + if resolved_ctx is not None: + claims[JwtField.CTX] = resolved_ctx + private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) + try: + signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + results.append(get_signed_data_token_response_object(prefix + signed_jwt, token)) + log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) + return results def generate_signed_data_tokens(credentials_file_path, options): log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value) try: - credentials_file =open(credentials_file_path, 'r') + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), + invalid_input_error_code) + except SkyflowError: + raise except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - 
credentials = json.load(credentials_file) - except Exception: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), - invalid_input_error_code) - - finally: - credentials_file.close() - return get_signed_tokens(credentials, options) def generate_signed_data_tokens_from_creds(credentials, options): @@ -168,9 +242,6 @@ def generate_signed_data_tokens_from_creds(credentials, options): raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value, invalid_input_error_code) return get_signed_tokens(json_credentials, options) + def get_signed_data_token_response_object(signed_token, actual_token): - response_object = { - "token": actual_token, - "signed_token": signed_token - } - return response_object.get("token"), response_object.get("signed_token") + return actual_token, signed_token diff --git a/skyflow/utils/__init__.py b/skyflow/utils/__init__.py index f2788b11..664cf65d 100644 --- a/skyflow/utils/__init__.py +++ b/skyflow/utils/__init__.py @@ -1,5 +1,5 @@ from ..utils.enums import LogLevel, Env, TokenType from ._skyflow_messages import SkyflowMessages from ._version import SDK_VERSION -from ._helpers import get_base_url, format_scope +from ._helpers import get_base_url, format_scope, is_valid_url from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_detected_entity_to_entity_info diff --git a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py index 97eecabc..12ff1257 100644 --- a/skyflow/utils/_helpers.py +++ b/skyflow/utils/_helpers.py @@ -8,4 +8,11 @@ def get_base_url(url): def format_scope(scopes): if not scopes: return None - 
return " ".join([f"role:{scope}" for scope in scopes]) \ No newline at end of file + return " ".join([f"role:{scope}" for scope in scopes]) + +def is_valid_url(url): + try: + result = urlparse(url) + return all([result.scheme == "https", result.netloc]) + except Exception: + return False \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..232bd8b0 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -4,6 +4,7 @@ error_prefix = f"Skyflow Python SDK {SDK_VERSION}" INFO = "INFO" +WARN = "WARN" ERROR = "ERROR" class SkyflowMessages: @@ -16,7 +17,7 @@ class ErrorCodes(Enum): REDACTION_WITH_TOKENS_NOT_SUPPORTED = 400 class Error(Enum): - GENERIC_API_ERROR = f"{error_prefix} Validation error. Invalid configuration. Please add a valid vault configuration." + GENERIC_API_ERROR = f"{error_prefix} API error. Error occurred." EMPTY_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id." INVALID_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id as a string." @@ -42,12 +43,13 @@ class Error(Enum): EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Specify a valid file path." EMPTY_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Specify a valid file path." INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Expected file path to be a string." - INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a string." + INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a valid file path." EMPTY_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. 
Invalid token for {{}} with id {{}}.Specify a valid credentials token." EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token." INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." - EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." + EXPIRED_BEARER_TOKEN = f"{error_prefix} Initialization failed. Bearer token is invalid or expired." + EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string." @@ -60,6 +62,8 @@ class Error(Enum): EMPTY_CONTEXT = f"{error_prefix} Initialization failed. Invalid context provided. Specify context as type Context." INVALID_CONTEXT_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid context for {{}} with id {{}}. Specify a valid context." INVALID_CONTEXT = f"{error_prefix} Initialization failed. Invalid context. Specify a valid context." + INVALID_CTX_TYPE = f"{error_prefix} Initialization failed. Invalid ctx type. Specify ctx as a string or a dict." + INVALID_CTX_MAP_KEY = f"{error_prefix} Initialization failed. Invalid key '{{}}' in ctx dict. Keys must contain only alphanumeric characters and underscores." INVALID_LOG_LEVEL = f"{error_prefix} Initialization failed. Invalid log level. Specify a valid log level." 
EMPTY_LOG_LEVEL = f"{error_prefix} Initialization failed. Specify a valid log level." @@ -71,6 +75,9 @@ class Error(Enum): RESPONSE_NOT_JSON = f"{error_prefix} Response {{}} is not valid JSON." API_ERROR = f"{error_prefix} Server returned status code {{}}" + INVALID_JSON_RESPONSE = f"{error_prefix} Invalid JSON response received." + UNKNOWN_ERROR_DEFAULT_MESSAGE = f"{error_prefix} An unknown error occurred." + INVALID_FILE_INPUT = f"{error_prefix} Validation error. Invalid file input. Specify a valid file input." INVALID_DETECT_ENTITIES_TYPE = f"{error_prefix} Validation error. Invalid type of detect entities. Specify detect entities as list of DetectEntities enum." INVALID_TYPE_FOR_DEFAULT_TOKEN_TYPE = f"{error_prefix} Validation error. Invalid type of default token type. Specify default token type as TokenType enum." @@ -84,14 +91,15 @@ class Error(Enum): INVALID_TABLE_NAME_IN_INSERT = f"{error_prefix} Validation error. Invalid table name in insert request. Specify a valid table name." INVALID_TYPE_OF_DATA_IN_INSERT = f"{error_prefix} Validation error. Invalid type of data in insert request. Specify data as a object array." EMPTY_DATA_IN_INSERT = f"{error_prefix} Validation error. Data array cannot be empty. Specify data in insert request." - INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. 'upsert' key cannot be empty in options. At least one object of table and column is required." + INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. Invalid 'upsert' value in options. Specify 'upsert' as a non-empty string containing the column name." INVALID_HOMOGENEOUS_TYPE = f"{error_prefix} Validation error. Invalid type of homogeneous. Specify homogeneous as a string." INVALID_TOKEN_MODE_TYPE = f"{error_prefix} Validation error. Invalid type of token mode. Specify token mode as a TokenMode enum." INVALID_RETURN_TOKENS_TYPE = f"{error_prefix} Validation error. Invalid type of return tokens. Specify return tokens as a boolean." 
INVALID_CONTINUE_ON_ERROR_TYPE = f"{error_prefix} Validation error. Invalid type of continue on error. Specify continue on error as a boolean." TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE = f"{error_prefix} Validation error. 'token_mode' wasn't specified. Set 'token_mode' to 'ENABLE' to insert tokens." INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT = f"{error_prefix} Validation error. 'token_mode' is set to 'ENABLE_STRICT', but some fields are missing tokens. Specify tokens for all fields." - NO_TOKENS_IN_INSERT = f"{error_prefix} Validation error. Tokens weren't specified for records while 'token_strict' was {{}}. Specify tokens." + MISMATCH_OF_FIELDS_AND_TOKENS = f"{error_prefix} Validation error. Keys for values and tokens are not matching. Ensure each values entry and its corresponding tokens entry have the same keys." + NO_TOKENS_IN_INSERT = f"{error_prefix} Validation error. Tokens weren't specified for records while 'token_mode' was {{}}. Specify tokens." BATCH_INSERT_FAILURE = f"{error_prefix} Insert operation failed." GET_FAILURE = f"{error_prefix} Get operation failed." HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT = f"{error_prefix} Validation error. Homogenous is not supported when upsert is passed." @@ -114,15 +122,16 @@ class Error(Enum): INVOKE_CONNECTION_FAILED = f"{error_prefix} Invoke Connection operation failed." INVALID_IDS_TYPE = f"{error_prefix} Validation error. 'ids' has a value of type {{}}. Specify 'ids' as list." - INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction' has a value of type {{}}. Specify 'redaction' as type Skyflow.RedactionType." - INVALID_COLUMN_NAME = f"{error_prefix} Validation error. 'column' has a value of type {{}}. Specify 'column' as a string." - INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. columnValues key has a value of type {{}}. Specify columnValues key as list." + INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction_type' has a value of type {{}}. 
Specify 'redaction_type' as type Skyflow.RedactionType." + INVALID_COLUMN_NAME = f"{error_prefix} Validation error. column_name has a value of type {{}}. Specify 'column' as a string." + INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. column_values key has a value of type {{}}. Specify column_values key as list." + INVALID_COLUMN_VALUES = f"{error_prefix} Validation error. column_values key is an empty list. Specify at least one column value when column_name is passed." INVALID_FIELDS_VALUE = f"{error_prefix} Validation error. fields key has a value of type{{}}. Specify fields key as list." - BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"${error_prefix} Validation error. Both offset and limit cannot be present at the same time" + BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"{error_prefix} Validation error. Both offset and limit cannot be present at the same time" INVALID_OFF_SET_VALUE = f"{error_prefix} Validation error. offset key has a value of type {{}}. Specify offset key as integer." INVALID_LIMIT_VALUE = f"{error_prefix} Validation error. limit key has a value of type {{}}. Specify limit key as integer." INVALID_DOWNLOAD_URL_VALUE = f"{error_prefix} Validation error. download_url key has a value of type {{}}. Specify download_url key as boolean." - REDACTION_WITH_TOKENS_NOT_SUPPORTED = f"{error_prefix} Validation error. 'redaction' can't be used when tokens are specified. Remove 'redaction' from payload if tokens are specified." + REDACTION_WITH_TOKENS_NOT_SUPPORTED = f"{error_prefix} Validation error. 'redaction_type' can't be used when tokens are specified. Remove 'redaction_type' from payload if tokens are specified." TOKENS_GET_COLUMN_NOT_SUPPORTED = f"{error_prefix} Validation error. Column name and/or column values can't be used when tokens are specified. Remove unique column values or tokens from the payload." BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED = f"{error_prefix} Validation error. Both Skyflow IDs and column details can't be specified. 
Either specify Skyflow IDs or unique column details." INVALID_ORDER_BY_VALUE = f"{error_prefix} Validation error. order_by key has a value of type {{}}. Specify order_by key as Skyflow.OrderBy" @@ -130,7 +139,7 @@ class Error(Enum): UPDATE_FIELD_KEY_ERROR = f"{error_prefix} Validation error. Fields are empty in an update payload. Specify at least one field." INVALID_FIELDS_TYPE = f"{error_prefix} Validation error. The 'data' key has a value of type {{}}. Specify 'data' as a dictionary." IDS_KEY_ERROR = f"{error_prefix} Validation error. 'ids' key is missing from the payload. Specify an 'ids' key." - INVALID_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. The 'data' field is invalid. Specify 'data' as a list of dictionaries containing 'token' and 'redaction'." + INVALID_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. The 'data' field is invalid. Specify 'data' as a list of dictionaries containing 'token' and 'redaction_type'." INVALID_DATA_FOR_DETOKENIZE = f"{error_prefix}" EMPTY_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. Tokens are empty in detokenize payload. Specify at lease one token" INVALID_TOKEN_TYPE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens should be of type string." @@ -153,10 +162,13 @@ class Error(Enum): MISSING_CLIENT_ID = f"{error_prefix} Initialization failed. Unable to read client ID in credentials. Verify your client ID." MISSING_KEY_ID = f"{error_prefix} Initialization failed. Unable to read key ID in credentials. Verify your key ID." MISSING_TOKEN_URI = f"{error_prefix} Initialization failed. Unable to read token URI in credentials. Verify your token URI." + INVALID_TOKEN_URI = f"{error_prefix} Initialization failed. Invalid Skyflow credentials. The token URI must be a string and a valid URL." JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials." JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials." 
FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents." INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." INVALID_TEXT_IN_DEIDENTIFY= f"{error_prefix} Validation error. The text field is required and must be a non-empty string. Specify a valid text." INVALID_ENTITIES_IN_DEIDENTIFY= f"{error_prefix} Validation error. The entities field must be an array of DetectEntities enums. Specify a valid entities." @@ -278,7 +290,6 @@ class Info(Enum): VALIDATING_FILE_UPLOAD_REQUEST = f"{INFO}: [{error_prefix}] Validating file upload request." FILE_UPLOAD_REQUEST_RESOLVED = f"{INFO}: [{error_prefix}] File upload request resolved." FILE_UPLOAD_SUCCESS = f"{INFO}: [{error_prefix}] File uploaded successfully." - FILE_UPLOAD_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] File upload failed." INVOKE_CONNECTION_TRIGGERED = f"{INFO}: [{error_prefix}] Invoke connection method triggered." VALIDATING_INVOKE_CONNECTION_REQUEST = f"{INFO}: [{error_prefix}] Validating invoke connection request." @@ -308,6 +319,8 @@ class Info(Enum): DETECT_REQUEST_RESOLVED = f"{INFO}: [{error_prefix}] Detect request is resolved." class ErrorLogs(Enum): + INVALID_LOG_LEVEL = f"{ERROR}: [{error_prefix}] Invalid log level. Specify a valid log level." + INVALID_KEY = f"{ERROR}: [{error_prefix}] Invalid key {{}} in config." VAULTID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid vault config. Vault ID is required." EMPTY_VAULTID = f"{ERROR}: [{error_prefix}] Invalid vault config. Vault ID can not be empty." CLUSTER_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid vault config. Cluster ID is required." 
@@ -332,6 +345,8 @@ class ErrorLogs(Enum): KEY_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Key ID is required." TOKEN_URI_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Token URI is required." INVALID_TOKEN_URI = f"{ERROR}: [{error_prefix}] Invalid value for token URI in credentials." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." TABLE_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Table is required." @@ -346,13 +361,14 @@ class ErrorLogs(Enum): EMPTY_OR_NULL_VALUE_IN_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Value can not be null or empty in tokens for key {{}}." EMPTY_OR_NULL_KEY_IN_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Key can not be null or empty in tokens." MISMATCH_OF_FIELDS_AND_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Keys for values and tokens are not matching." + FILE_UPLOAD_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] File upload failed." EMPTY_IDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Ids can not be empty." EMPTY_OR_NULL_ID_IN_IDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Id can not be null or empty in ids at index {{}}." TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokenization is not supported when redaction is applied." TOKENIZATION_SUPPORTED_ONLY_WITH_IDS=f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokenization is not supported when column name and values are passed." - TOKENS_NOT_ALLOWED_WITH_BYOT_DISABLE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are not allowed when token_strict is DISABLE." - INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT =f"{ERROR}: [{error_prefix}] Invalid {{}} request. For tokenStrict as ENABLE_STRICT, tokens should be passed for all fields." 
+ TOKENS_NOT_ALLOWED_WITH_BYOT_DISABLE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are not allowed when token_mode is DISABLE." + INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT =f"{ERROR}: [{error_prefix}] Invalid {{}} request. For token_mode as ENABLE_STRICT, tokens should be passed for all fields." TOKENS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are required." EMPTY_FIELDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Fields can not be empty." EMPTY_OFFSET = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Offset ca not be empty." @@ -363,7 +379,7 @@ class ErrorLogs(Enum): SKYFLOW_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id is required." EMPTY_SKYFLOW_ID = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id can not be empty." - COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. ColumnValues are required." + COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. column_values are required." EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Column group can not be null or empty in column values at index %s2." EMPTY_QUERY= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Query can not be empty." @@ -386,6 +402,7 @@ class ErrorLogs(Enum): SAVING_DEIDENTIFY_FILE_FAILED = f"{ERROR}: [{error_prefix}] Error while saving deidentified file to output directory." REIDENTIFY_TEXT_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Reidentify text resulted in failure." DETECT_FILE_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Deidentify file resulted in failure." 
+ EMPTY_FILE_COLUMN_NAME = f"{ERROR}: [{error_prefix}] Empty column name in FILE_UPLOAD" class Interface(Enum): INSERT = "INSERT" @@ -400,7 +417,18 @@ class HttpStatus(Enum): BAD_REQUEST = "Bad Request" class Warning(Enum): - WARNING_MESSAGE = "WARNING MESSAGE" + DETOKENIZE_REDACTION_KEY_DEPRECATED = ( + f"{WARN}: [{error_prefix}] 'redaction' key in detokenize data is deprecated and will be removed in a future version. Use 'redaction_type' instead." + ) + UPDATE_LOG_LEVEL_DEPRECATED = ( + f"{WARN}: [{error_prefix}] Skyflow.update_log_level() is deprecated. " + "Use Skyflow.set_log_level() instead." + ) + FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED = ( + f"{WARN}: [{error_prefix}] FileUploadRequest: argument order changed. " + "Old positional order: (table, skyflow_id, column_name). " + "New order: FileUploadRequest(table, column_name=..., skyflow_id=...)." + ) diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 4278357e..e3b8eea9 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -20,7 +20,9 @@ from skyflow.vault.detect import DeidentifyTextResponse, ReidentifyTextResponse from skyflow.vault.detect import EntityInfo, TextIndex from . 
import SkyflowMessages, SDK_VERSION -from .constants import PROTOCOL +from .constants import (PROTOCOL, HttpHeader, ApiKey, ContentType as ContentTypeConstants, + EncodingType, BooleanString, ResponseField, CredentialField, SdkPrefix, + SdkMetricsKey, ErrorDefaults, HttpStatusCode) from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params @@ -30,29 +32,21 @@ invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): - dotenv.load_dotenv() + if config_level_creds is not None: + return config_level_creds + if common_skyflow_creds is not None: + return common_skyflow_creds dotenv_path = dotenv.find_dotenv(usecwd=True) if dotenv_path: load_dotenv(dotenv_path) env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") - if config_level_creds: - return config_level_creds - if common_skyflow_creds: - return common_skyflow_creds if env_skyflow_credentials: - env_skyflow_credentials.strip() - try: - env_creds = env_skyflow_credentials.replace('\n', '\\n') - return { - 'credentials_string': env_creds - } - except json.JSONDecodeError: - raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + env_creds = env_skyflow_credentials.strip().replace('\n', '\\n') + return {'credentials_string': env_creds} + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False api_key_pattern = 
re.compile(r'^sky-[a-zA-Z0-9]{5}-[a-fA-F0-9]{32}$') @@ -78,9 +72,9 @@ def parse_path_params(url, path_params): return result -def to_lowercase_keys(dict): +def to_lowercase_keys(data): result = {} - for key, value in dict.items(): + for key, value in data.items(): result[key.lower()] = value return result @@ -104,31 +98,45 @@ def convert_detected_entity_to_entity_info(detected_entity): def construct_invoke_connection_request(request, connection_url, logger) -> PreparedRequest: url = parse_path_params(connection_url.rstrip('/'), request.path_params) - try: - if isinstance(request.headers, dict): - header = to_lowercase_keys(json.loads( - json.dumps(request.headers))) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + header = None + content_type = None - if not 'Content-Type'.lower() in header: - header['content-type'] = ContentType.JSON.value + if request.headers is not None: + try: + if isinstance(request.headers, dict): + header = to_lowercase_keys(json.loads( + json.dumps(request.headers))) + + content_type = header.get(HttpHeader.CONTENT_TYPE_LOWERCASE) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + except SkyflowError: + raise + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - try: - if isinstance(request.body, dict): - json_data, files = get_data_from_content_type( - request.body, header["content-type"] - ) - else: + json_data = None + files = {} + + if request.body is not None: + try: + if isinstance(request.body, dict): + json_data, files = get_data_from_content_type( + request.body, content_type + ) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + except SkyflowError: + raise + except 
Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) - except Exception as e: - raise SkyflowError( SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + + if files and header and content_type == ContentType.FORMDATA.value: + header.pop(HttpHeader.CONTENT_TYPE_LOWERCASE, None) validate_invoke_connection_params(logger, request.query_params, request.path_params) - if not hasattr(request.method, 'value'): + if not hasattr(request.method, ResponseField.VALUE): raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_METHOD.value, invalid_input_error_code) try: @@ -174,19 +182,59 @@ def render_key(parents): def get_data_from_content_type(data, content_type): converted_data = data files = {} + if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) elif content_type == ContentType.FORMDATA.value: - converted_data = r_urlencode(list(), dict(), data) - files = {(None, None)} + converted_data = None + files = {} + for key, value in data.items(): + files[key] = (None, str(value)) elif content_type == ContentType.JSON.value: converted_data = json.dumps(data) + elif content_type == ContentType.XML.value or content_type == 'application/xml' or content_type == 'text/xml': + if isinstance(data, dict): + converted_data = dict_to_xml(data) + else: + converted_data = str(data) + elif content_type == ContentType.HTML.value or content_type == 'text/html': + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) + else: + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) return converted_data, files +def dict_to_xml(data, root_tag='root'): + def build_xml(d, tag='item'): + if isinstance(d, dict): + xml_parts = [f'<{tag}>'] + for key, value in d.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + elif isinstance(d, list): + return 
''.join([build_xml(item, tag) for item in d]) + else: + return f'<{tag}>{d}' + + xml_parts = [f'<{root_tag}>'] + for key, value in data.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + + +_CACHED_METRICS: dict = {} def get_metrics(): - sdk_name_version = "skyflow-python@" + SDK_VERSION + if _CACHED_METRICS: + return _CACHED_METRICS try: sdk_client_device_model = platform.node() @@ -203,43 +251,43 @@ def get_metrics(): except Exception: sdk_runtime_details = "" - details_dic = { - 'sdk_name_version': sdk_name_version, - 'sdk_client_device_model': sdk_client_device_model, - 'sdk_client_os_details': sdk_client_os_details, - 'sdk_runtime_details': "Python " + sdk_runtime_details, - } - return details_dic + _CACHED_METRICS.update({ + SdkMetricsKey.SDK_NAME_VERSION: SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION, + SdkMetricsKey.SDK_CLIENT_DEVICE_MODEL: sdk_client_device_model, + SdkMetricsKey.SDK_CLIENT_OS_DETAILS: sdk_client_os_details, + SdkMetricsKey.SDK_RUNTIME_DETAILS: SdkPrefix.PYTHON_RUNTIME + sdk_runtime_details, + }) + return _CACHED_METRICS def parse_insert_response(api_response, continue_on_error): # Retrieve the headers and data from the API response api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) inserted_fields = [] errors = [] insert_response = InsertResponse() if continue_on_error: for idx, response in enumerate(api_response_data.responses): - if response['Status'] == 200: - body = response['Body'] - if 'records' in body: - for record in body['records']: + if response[ResponseField.STATUS] == HttpStatusCode.OK: + body = response[ResponseField.BODY] + if ResponseField.RECORDS in body: + for record in body[ResponseField.RECORDS]: inserted_field = { - 'skyflow_id': record['skyflow_id'], - 'request_index': idx + 
ResponseField.SKYFLOW_ID: record[ResponseField.SKYFLOW_ID], + ResponseField.REQUEST_INDEX: idx } - if 'tokens' in record: - inserted_field.update(record['tokens']) + if ResponseField.TOKENS in record: + inserted_field.update(record[ResponseField.TOKENS]) inserted_fields.append(inserted_field) - elif response['Status'] == 400: + elif response[ResponseField.STATUS] == HttpStatusCode.BAD_REQUEST: error = { - 'request_index': idx, - 'request_id': request_id, - 'error': response['Body']['error'], - 'http_code': response['Status'], + ResponseField.REQUEST_INDEX: idx, + ResponseField.REQUEST_ID: request_id, + ResponseField.ERROR: response[ResponseField.BODY][ResponseField.ERROR], + ResponseField.HTTP_CODE: response[ResponseField.STATUS], } errors.append(error) @@ -248,7 +296,7 @@ def parse_insert_response(api_response, continue_on_error): else: for record in api_response_data.records: field_data = { - 'skyflow_id': record.skyflow_id + ResponseField.SKYFLOW_ID: record.skyflow_id } if record.tokens: @@ -263,7 +311,7 @@ def parse_insert_response(api_response, continue_on_error): def parse_update_record_response(api_response: V1UpdateRecordResponse): update_response = UpdateResponse() updated_field = dict() - updated_field['skyflow_id'] = api_response.skyflow_id + updated_field[ResponseField.SKYFLOW_ID] = api_response.skyflow_id if api_response.tokens is not None: updated_field.update(api_response.tokens) @@ -293,23 +341,23 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) detokenized_fields = [] errors = [] for record in api_response_data.records: if record.error: errors.append({ - "token": record.token, - "error": record.error, - "request_id": request_id + ResponseField.TOKEN: record.token, + 
ResponseField.ERROR: record.error, + ResponseField.REQUEST_ID: request_id }) else: value_type = record.value_type if record.value_type else None detokenized_fields.append({ - "token": record.token, - "value": record.value, - "type": value_type + ResponseField.TOKEN: record.token, + ResponseField.VALUE: record.value, + ResponseField.TYPE: value_type }) detokenized_fields = detokenized_fields @@ -322,7 +370,7 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): def parse_tokenize_response(api_response: V1TokenizeResponse): tokenize_response = TokenizeResponse() - tokenized_fields = [{"token": record.token} for record in api_response.records] + tokenized_fields = [{ResponseField.TOKEN: record.token} for record in api_response.records] tokenize_response.tokenized_fields = tokenized_fields @@ -334,7 +382,7 @@ def parse_query_response(api_response: V1GetQueryResponse): for record in api_response.records: field_object = { **record.fields, - "tokenized_data": {} + ResponseField.TOKENIZED_DATA: {} } fields.append(field_object) query_response.fields = fields @@ -344,40 +392,59 @@ def parse_invoke_connection_response(api_response: requests.Response): status_code = api_response.status_code content = api_response.content if isinstance(content, bytes): - content = content.decode('utf-8') + content = content.decode(EncodingType.UTF_8) + try: api_response.raise_for_status() - try: - data = json.loads(content) - metadata = {} - if 'x-request-id' in api_response.headers: - metadata['request_id'] = api_response.headers['x-request-id'] - - return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) - except Exception as e: - raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code) + + content_type = api_response.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE, '').lower() + + if ContentTypeConstants.APPLICATION_JSON in content_type or not content_type: + try: + data = json.loads(content) + except 
json.JSONDecodeError: + data = content + else: + data = content + + metadata = {} + if HttpHeader.X_REQUEST_ID in api_response.headers: + metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID] + + return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) + except HTTPError: message = SkyflowMessages.Error.API_ERROR.value.format(status_code) + request_id = api_response.headers.get(HttpHeader.X_REQUEST_ID) + try: - error_response = json.loads(content) - request_id = api_response.headers['x-request-id'] - error_from_client = api_response.headers.get('error-from-client') - - status_code = error_response.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = error_response.get('error', {}).get('http_status') - grpc_code = error_response.get('error', {}).get('grpc_code') - details = error_response.get('error', {}).get('details') - message = error_response.get('error', {}).get('message', "An unknown error occurred.") - + error_response = json.loads(content) + error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) + + http_status = None + grpc_code = None + details = None + + error_obj = error_response.get(ResponseField.ERROR) if isinstance(error_response, dict) else None + if isinstance(error_obj, dict): + status_code = error_obj.get(ResponseField.HTTP_CODE, status_code) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS) + message = error_obj.get(ResponseField.MESSAGE, message) + elif isinstance(error_obj, str) and error_obj: + message = error_obj + if error_from_client is not None: - if details is None: details = [] - error_from_client_bool = error_from_client.lower() == 'true' - details.append({'error_from_client': error_from_client_bool}) + if details is None: + details = [] + error_from_client_bool = error_from_client.lower() == BooleanString.TRUE + 
details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) + except json.JSONDecodeError: - message = SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content) - raise SkyflowError(message, status_code) + raise SkyflowError(message, status_code, request_id) def parse_deidentify_text_response(api_response: DeidentifyStringResponse): entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities] @@ -395,51 +462,79 @@ def log_and_reject_error(description, status_code, request_id, http_status=None, raise SkyflowError(description, status_code, request_id, grpc_code, http_status, details) def handle_exception(error, logger): - # handle invalid cluster ID error scenario - if (isinstance(error, httpx.ConnectError)): - handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) + if isinstance(error, httpx.ConnectError): + description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=logger) + return + + if not hasattr(error, 'headers') or not hasattr(error, 'body') or error.headers is None or error.body is None: + description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=logger) + return - request_id = error.headers.get('x-request-id', 'unknown-request-id') - content_type = error.headers.get('content-type') + request_id = error.headers.get(HttpHeader.X_REQUEST_ID, ErrorDefaults.UNKNOWN_REQUEST_ID) + content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) data = error.body if content_type: - if 'application/json' in content_type: + if ContentTypeConstants.APPLICATION_JSON in content_type: handle_json_error(error, data, request_id, logger) - elif 'text/plain' 
in content_type: + elif ContentTypeConstants.TEXT_PLAIN in content_type: handle_text_error(error, data, request_id, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) def handle_json_error(err, data, request_id, logger): try: - if isinstance(data, dict): # If data is already a dict + if isinstance(data, dict): description = data elif isinstance(data, ErrorResponse): description = data.dict() else: description = json.loads(data) - status_code = description.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = description.get('error', {}).get('http_status') - grpc_code = description.get('error', {}).get('grpc_code') - details = description.get('error', {}).get('details', []) - description_message = description.get('error', {}).get('message', "An unknown error occurred.") - log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger = logger) + if ResponseField.ERROR in description: + error_obj = description.get(ResponseField.ERROR, {}) + status_code = error_obj.get(ResponseField.HTTP_CODE, HttpStatusCode.INTERNAL_SERVER_ERROR) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS, []) + description_message = error_obj.get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + elif ResponseField.RESPONSES in description: + responses = description.get(ResponseField.RESPONSES, []) + messages = [] + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + for resp in responses: + resp_status = resp.get(ResponseField.STATUS, HttpStatusCode.INTERNAL_SERVER_ERROR) + resp_body = resp.get(ResponseField.BODY, {}) + if isinstance(resp_status, int) and 
resp_status >= HttpStatusCode.BAD_REQUEST: + status_code = resp_status + error_msg = resp_body.get(ResponseField.ERROR) + if error_msg: + messages.append(str(error_msg)) + description_message = '; '.join(messages) if messages else SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + http_status = None + grpc_code = None + details = [] + else: + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + http_status = None + grpc_code = None + details = [] + description_message = SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + + log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger=logger) except json.JSONDecodeError: - log_and_reject_error("Invalid JSON response received.", err, request_id, logger = logger) + log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger=logger) def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) -def handle_generic_error(err, request_id, logger): - handle_generic_error(err, request_id, err.status, logger = logger) - -def handle_generic_error(err, request_id, status, logger): - description = SkyflowMessages.Error.GENERIC_API_ERROR.value - log_and_reject_error(description, status, request_id, logger = logger) +def handle_generic_error_with_status(err, request_id, status, logger): + description = str(err) if err else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, status, request_id, logger=logger) def encode_column_values(get_request): encoded_column_values = list() diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 0d05fc30..35b3f3a2 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0' \ No newline at end of file +SDK_VERSION = '2.0.2.dev0+961116e' diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index ef20faf8..05d28380 100644 --- 
a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -1,4 +1,291 @@ OPTIONAL_TOKEN='token' PROTOCOL='https' SKY_META_DATA_HEADER='sky-metadata' +CTX_KEY_REGEX=r'^[a-zA-Z0-9_]+$' +class SKYFLOW: + SKYFLOW_ID = 'skyflowId' + X_SKYFLOW_AUTHORIZATION = 'x-skyflow-authorization' + + +class HttpHeader: + CONTENT_TYPE = 'Content-Type' + CONTENT_TYPE_LOWERCASE = 'content-type' + X_REQUEST_ID = 'x-request-id' + ERROR_FROM_CLIENT = 'error-from-client' + AUTHORIZATION = 'Authorization' + X_SKYFLOW_AUTHORIZATION_HEADER = 'X-Skyflow-Authorization' + + +class HttpStatusCode: + OK = 200 + BAD_REQUEST = 400 + UNAUTHORIZED = 401 + INTERNAL_SERVER_ERROR = 500 + + +class ContentType: + APPLICATION_JSON = 'application/json' + APPLICATION_X_WWW_FORM_URLENCODED = 'application/x-www-form-urlencoded' + TEXT_PLAIN = 'text/plain' + + +class DetectStatus: + IN_PROGRESS = 'IN_PROGRESS' + SUCCESS = 'SUCCESS' + FAILED = 'FAILED' + UNKNOWN = 'UNKNOWN' + +class Detect: + WAIT_TIME = 64 + +class FileExtension: + JSON = 'json' + MP3 = 'mp3' + WAV = 'wav' + PDF = 'pdf' + TXT = 'txt' + DOC = 'doc' + DOCX = 'docx' + JPG = 'jpg' + JPEG = 'jpeg' + PNG = 'png' + BMP = 'bmp' + TIF = 'tif' + TIFF = 'tiff' + PPT = 'ppt' + PPTX = 'pptx' + CSV = 'csv' + XLS = 'xls' + XLSX = 'xlsx' + XML = 'xml' + + +class FileProcessing: + PROCESSED_PREFIX = 'processed-' + DEIDENTIFIED_PREFIX = 'deidentified.' 
+ ENTITIES = 'entities' + + +class EncodingType: + UTF8 = 'utf8' + UTF_8 = 'utf-8' + BASE64 = 'base64' + BINARY = 'binary' + + +class JWT: + ALGORITHM_RS256 = 'RS256' + GRANT_TYPE_JWT_BEARER = 'urn:ietf:params:oauth:grant-type:jwt-bearer' + ISSUER_SDK = 'sdk' + SIGNED_TOKEN_PREFIX = 'signed_token_' + ROLE_PREFIX = 'role:' + + +class ApiKey: + SKY_PREFIX = 'sky-' + LENGTH = 42 + + +class UrlProtocol: + HTTPS = 'https' + HTTP = 'http' + + +class BooleanString: + TRUE = 'true' + FALSE = 'false' + + +class ResponseField: + STATUS = 'Status' + BODY = 'Body' + RECORDS = 'records' + TOKENS = 'tokens' + ERROR = 'error' + SKYFLOW_ID = 'skyflow_id' + REQUEST_INDEX = 'request_index' + REQUEST_ID = 'request_id' + HTTP_CODE = 'http_code' + HTTP_STATUS = 'http_status' + GRPC_CODE = 'grpc_code' + DETAILS = 'details' + MESSAGE = 'message' + ERROR_FROM_CLIENT = 'error_from_client' + TOKEN = 'token' + VALUE = 'value' + TYPE = 'type' + TOKENIZED_DATA = 'tokenized_data' + SIGNED_TOKEN = 'signed_token' + RESPONSES = 'responses' + + +class CredentialField: + PRIVATE_KEY = 'privateKey' + CLIENT_ID = 'clientID' + KEY_ID = 'keyID' + TOKEN_URI = 'tokenURI' + TOKEN_URI_OPTION = 'token_uri' + CLIENT_NAME = 'clientName' + CREDENTIALS_STRING = 'credentials_string' + API_KEY = 'api_key' + TOKEN = 'token' + PATH = 'path' + CONTEXT = 'context' + ROLES = 'roles' + + +class JwtField: + ISS = 'iss' + KEY = 'key' + AUD = 'aud' + SUB = 'sub' + EXP = 'exp' + CTX = 'ctx' + TOK = 'tok' + IAT = 'iat' + + +class OptionField: + ROLE_IDS = 'role_ids' + DATA_TOKENS = 'data_tokens' + TIME_TO_LIVE = 'time_to_live' + ROLES = 'roles' + CTX = 'ctx' + VAULT_ID = 'vault_id' + CONNECTION_ID = 'connection_id' + CONNECTION_URL = 'connection_url' + VAULT_CLIENT = 'vault_client' + VAULT_CONTROLLER = 'vault_controller' + DETECT_CONTROLLER = 'detect_controller' + CONTROLLER = 'controller' + VERIFY_SIGNATURE = 'verify_signature' + VERIFY_AUD = 'verify_aud' + + +class ConfigField: + CREDENTIALS = 'credentials' + CLUSTER_ID = 
'cluster_id' + ENV = 'env' + VAULT_ID = 'vault_id' + + +class RequestParameter: + VALUE = 'value' + COLUMN_GROUP = 'column_group' + REDACTION = 'redaction' + REDACTION_TYPE = 'redaction_type' + + +class FileUploadField: + TABLE = 'table' + SKYFLOW_ID = 'skyflow_id' + COLUMN_NAME = 'column_name' + FILE_PATH = 'file_path' + BASE64 = 'base64' + FILE_OBJECT = 'file_object' + FILE_NAME = 'file_name' + FILE = 'file' + NAME = 'name' + + +class DeidentifyFileRequestField: + ENTITIES = 'entities' + ALLOW_REGEX_LIST = 'allow_regex_list' + RESTRICT_REGEX_LIST = 'restrict_regex_list' + OUTPUT_PROCESSED_IMAGE = 'output_processed_image' + OUTPUT_OCR_TEXT = 'output_ocr_text' + MASKING_METHOD = 'masking_method' + PIXEL_DENSITY = 'pixel_density' + DENSITY = 'density' + MAX_RESOLUTION = 'max_resolution' + OUTPUT_PROCESSED_AUDIO = 'output_processed_audio' + OUTPUT_TRANSCRIPTION = 'output_transcription' + BLEEP = 'bleep' + OUTPUT_DIRECTORY = 'output_directory' + WAIT_TIME = 'wait_time' + + +class DeidentifyField: + TEXT = 'text' + ENTITY_TYPES = 'entity_types' + TOKEN_TYPE = 'token_type' + ALLOW_REGEX = 'allow_regex' + RESTRICT_REGEX = 'restrict_regex' + TRANSFORMATIONS = 'transformations' + FORMAT = 'format' + OUTPUT = 'output' + STATUS = 'status' + RUN_ID = 'run_id' + WORD_CHARACTER_COUNT = 'word_character_count' + WORD_COUNT = 'word_count' + CHARACTER_COUNT = 'character_count' + SIZE = 'size' + DURATION = 'duration' + PAGES = 'pages' + SLIDES = 'slides' + PROCESSED_FILE = 'processed_file' + PROCESSED_FILE_TYPE = 'processed_file_type' + PROCESSED_FILE_EXTENSION = 'processed_file_extension' + REDACTED_FILE = 'redacted_file' + SHIFT_DATES = 'shift_dates' + DEFAULT = 'default' + ENTITY_UNQ_COUNTER = 'entity_unq_counter' + ENTITY_UNIQUE_COUNTER = 'entity_unique_counter' + ENTITY_ONLY = 'entity_only' + VAULT_TOKEN = 'vault_token' + ENTITIES = 'entities' + MAX_DAYS = 'max_days' + MIN_DAYS = 'min_days' + MAX = 'max' + MIN = 'min' + FILE = 'file' + TYPE = 'type' + EXTENSION = 'extension' + 
IN_PROGRESS = 'IN_PROGRESS' + REQUEST_OPTIONS = 'request_options' + BLEEP_GAIN = 'bleep_gain' + BLEEP_FREQUENCY = 'bleep_frequency' + BLEEP_START_PADDING = 'bleep_start_padding' + BLEEP_STOP_PADDING = 'bleep_stop_padding' + DENSITY = 'density' + TOKEN_FORMAT = 'token_format' + PROCESSED_FILE_RESPONSE_KEY = 'processedFile' + PROCESSED_FILE_TYPE_RESPONSE_KEY = 'processedFileType' + PROCESSED_FILE_EXTENSION_RESPONSE_KEY = 'processedFileExtension' + + +class RequestOperation: + INSERT = 'INSERT' + DELETE = 'DELETE' + GET = 'GET' + UPDATE = 'UPDATE' + QUERY = 'QUERY' + TOKENIZE = 'TOKENIZE' + DETOKENIZE = 'DETOKENIZE' + FILE_UPLOAD = 'FILE_UPLOAD' + + +class ConfigType: + VAULT = 'vault' + CONNECTION = 'connection' + + +class SqlCommand: + SELECT = 'SELECT' + + +class SdkPrefix: + SKYFLOW_PYTHON = 'skyflow-python@' + PYTHON_RUNTIME = 'Python ' + + +class SdkMetricsKey: + SDK_NAME_VERSION = 'sdk_name_version' + SDK_CLIENT_DEVICE_MODEL = 'sdk_client_device_model' + SDK_CLIENT_OS_DETAILS = 'sdk_client_os_details' + SDK_RUNTIME_DETAILS = 'sdk_runtime_details' + + +class ErrorDefaults: + UNKNOWN_REQUEST_ID = 'unknown-request-id' diff --git a/skyflow/utils/enums/content_types.py b/skyflow/utils/enums/content_types.py index 362c286a..f2db5b92 100644 --- a/skyflow/utils/enums/content_types.py +++ b/skyflow/utils/enums/content_types.py @@ -5,4 +5,5 @@ class ContentType(Enum): PLAINTEXT = 'text/plain' XML = 'text/xml' URLENCODED = 'application/x-www-form-urlencoded' - FORMDATA = 'multipart/form-data' \ No newline at end of file + FORMDATA = 'multipart/form-data' + HTML = 'text/html' \ No newline at end of file diff --git a/skyflow/utils/enums/detect_output_transcriptions.py b/skyflow/utils/enums/detect_output_transcriptions.py index 4e14f911..a398a3d8 100644 --- a/skyflow/utils/enums/detect_output_transcriptions.py +++ b/skyflow/utils/enums/detect_output_transcriptions.py @@ -4,4 +4,5 @@ class DetectOutputTranscriptions(Enum): DIARIZED_TRANSCRIPTION = "diarized_transcription" 
MEDICAL_DIARIZED_TRANSCRIPTION = "medical_diarized_transcription" MEDICAL_TRANSCRIPTION = "medical_transcription" - TRANSCRIPTION = "transcription" \ No newline at end of file + TRANSCRIPTION = "transcription" + PLAINTEXT_TRANSCRIPTION = "plaintext_transcription" \ No newline at end of file diff --git a/skyflow/utils/logger/__init__.py b/skyflow/utils/logger/__init__.py index 2993b8fc..bce55608 100644 --- a/skyflow/utils/logger/__init__.py +++ b/skyflow/utils/logger/__init__.py @@ -1,2 +1,2 @@ from ._logger import Logger -from ._log_helpers import log_error, log_info, log_error_log \ No newline at end of file +from ._log_helpers import log_error, log_info, log_warn, log_error_log, set_active_log_level \ No newline at end of file diff --git a/skyflow/utils/logger/_log_helpers.py b/skyflow/utils/logger/_log_helpers.py index fdb11ea9..1343b55f 100644 --- a/skyflow/utils/logger/_log_helpers.py +++ b/skyflow/utils/logger/_log_helpers.py @@ -1,5 +1,13 @@ from ..enums import LogLevel from . import Logger +from ..constants import ResponseField + +_active_log_level = LogLevel.ERROR + + +def set_active_log_level(level): + global _active_log_level + _active_log_level = level def log_info(message, logger = None): @@ -8,6 +16,11 @@ def log_info(message, logger = None): logger.info(message) +def log_warn(message, logger=None): + if not logger: + logger = Logger(_active_log_level) + logger.warn(message) + def log_error_log(message, logger=None): if not logger: logger = Logger(LogLevel.ERROR) @@ -18,17 +31,17 @@ def log_error(message, http_code, request_id=None, grpc_code=None, http_status=N logger = Logger(LogLevel.ERROR) log_data = { - 'http_code': http_code, - 'message': message + ResponseField.HTTP_CODE: http_code, + ResponseField.MESSAGE: message } if grpc_code is not None: - log_data['grpc_code'] = grpc_code + log_data[ResponseField.GRPC_CODE] = grpc_code if http_status is not None: - log_data['http_status'] = http_status + log_data[ResponseField.HTTP_STATUS] = http_status 
if request_id is not None: - log_data['request_id'] = request_id + log_data[ResponseField.REQUEST_ID] = request_id if details is not None: - log_data['details'] = details + log_data[ResponseField.DETAILS] = details logger.error(log_data) \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..42abe188 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -6,62 +6,83 @@ MaskingMethod from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_info, log_error_log +from skyflow.utils.constants import ( + ApiKey, ResponseField, RequestParameter, + FileUploadField, + DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField, Detect +) +from skyflow.utils.logger import log_info, log_warn, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest from skyflow.vault.detect._file_input import FileInput - -valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] -valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] -valid_credentials_keys = ["path", "roles", "context", "token", "credentials_string"] +from skyflow.utils._helpers import is_valid_url + +valid_vault_config_keys = [ + ConfigField.VAULT_ID, + ConfigField.CLUSTER_ID, + ConfigField.CREDENTIALS, + ConfigField.ENV +] +valid_connection_config_keys = [ + OptionField.CONNECTION_ID, + OptionField.CONNECTION_URL, + ConfigField.CREDENTIALS +] +valid_credentials_keys = [ + CredentialField.PATH, + CredentialField.ROLES, + CredentialField.CONTEXT, + CredentialField.TOKEN, + CredentialField.CREDENTIALS_STRING +] invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def validate_required_field(logger, 
config, field_name, expected_type, empty_error, invalid_error): field_value = config.get(field_name) if field_name not in config or not isinstance(field_value, expected_type): - if field_name == "vault_id": - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) - if field_name == "cluster_id": - logger.error(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value) - if field_name == "connection_id": - logger.error(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value) - if field_name == "connection_url": - logger.error(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value) + if field_name == ConfigField.VAULT_ID: + log_error_log(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value, logger) + if field_name == ConfigField.CLUSTER_ID: + log_error_log(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value, logger) + if field_name == OptionField.CONNECTION_ID: + log_error_log(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value, logger) + if field_name == OptionField.CONNECTION_URL: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value, logger) raise SkyflowError(invalid_error, invalid_input_error_code) if isinstance(field_value, str) and not field_value.strip(): - if field_name == "vault_id": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value) - if field_name == "cluster_id": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value) - if field_name == "connection_id": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value) - if field_name == "connection_url": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value) - if field_name == "path": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value) - if field_name == "credentials_string": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value) - if field_name == "token": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value) - if field_name == "api_key": - 
logger.error(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value) + if field_name == ConfigField.VAULT_ID: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value, logger) + if field_name == ConfigField.CLUSTER_ID: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value, logger) + if field_name == OptionField.CONNECTION_ID: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value, logger) + if field_name == OptionField.CONNECTION_URL: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value, logger) + if field_name == CredentialField.PATH: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value, logger) + if field_name == CredentialField.CREDENTIALS_STRING: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value, logger) + if field_name == CredentialField.TOKEN: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value, logger) + if field_name == CredentialField.API_KEY: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value, logger) raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if not api_key.startswith('sky-'): + if not api_key.startswith(ApiKey.SKY_PREFIX): log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger=logger) return False - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False return True def validate_credentials(logger, credentials, config_id_type=None, config_id=None): - key_present = [k for k in ["path", "token", "credentials_string", "api_key"] if credentials.get(k)] + key_present = [k for k in [CredentialField.PATH, CredentialField.TOKEN, CredentialField.CREDENTIALS_STRING, CredentialField.API_KEY] if credentials.get(k)] if len(key_present) == 0: error_message = ( @@ -69,6 +90,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if 
config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) elif len(key_present) > 1: error_message = ( @@ -76,79 +98,90 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) - if "roles" in credentials: + if CredentialField.ROLES in credentials: validate_required_field( - logger, credentials, "roles", list, + logger, credentials, CredentialField.ROLES, list, SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE.value, SkyflowMessages.Error.EMPTY_ROLES_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_ROLES.value ) - if "context" in credentials: + if CredentialField.CONTEXT in credentials: validate_required_field( - logger, credentials, "context", str, + logger, credentials, CredentialField.CONTEXT, str, SkyflowMessages.Error.EMPTY_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CONTEXT.value, SkyflowMessages.Error.INVALID_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CONTEXT.value ) - if "credentials_string" in credentials: + if CredentialField.CREDENTIALS_STRING in credentials: validate_required_field( - logger, credentials, "credentials_string", str, + logger, credentials, CredentialField.CREDENTIALS_STRING, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else 
SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING.value, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value ) - elif "path" in credentials: + elif CredentialField.PATH in credentials: validate_required_field( - logger, credentials, "path", str, + logger, credentials, CredentialField.PATH, str, SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH.value, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value ) - elif "token" in credentials: + elif CredentialField.TOKEN in credentials: validate_required_field( - logger, credentials, "token", str, + logger, credentials, CredentialField.TOKEN, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) - if is_expired(credentials.get("token"), logger): + if is_expired(credentials.get(CredentialField.TOKEN), logger): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value, logger) raise SkyflowError( - SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) - if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, + SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value + if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value, invalid_input_error_code ) - elif "api_key" in credentials: + elif CredentialField.API_KEY in 
credentials: validate_required_field( - logger, credentials, "api_key", str, + logger, credentials, CredentialField.API_KEY, str, SkyflowMessages.Error.EMPTY_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_API_KEY.value, SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value ) - if not validate_api_key(credentials.get("api_key"), logger): + if not validate_api_key(credentials.get(CredentialField.API_KEY), logger): raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) + + if CredentialField.TOKEN_URI_OPTION in credentials: + token_uri = credentials.get(CredentialField.TOKEN_URI_OPTION) + if ( + token_uri is None + or not isinstance(token_uri, str) + or not is_valid_url(token_uri) + ): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): - raise SkyflowError( SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) - - if log_level is None: - raise SkyflowError(SkyflowMessages.Error.EMPTY_LOG_LEVEL.value, invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.INVALID_LOG_LEVEL.value, logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) def validate_keys(logger, config, config_keys): for key in config.keys(): if key not in config_keys: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_KEY.value.format(key), logger) raise SkyflowError(SkyflowMessages.Error.INVALID_KEY.value.format(key), invalid_input_error_code) def validate_vault_config(logger, config): @@ -157,28 +190,28 @@ def 
validate_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, "vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) # Validate cluster_id (string, not empty) validate_required_field( - logger, config, "cluster_id", str, + logger, config, ConfigField.CLUSTER_ID, str, SkyflowMessages.Error.EMPTY_CLUSTER_ID.value.format(vault_id), SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id) ) # Validate credentials (dict, not empty) - if "credentials" in config and not config.get("credentials"): - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS in config and not config.get(ConfigField.CREDENTIALS): + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - if "credentials" in config and config.get("credentials"): - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + if ConfigField.CREDENTIALS in config and config.get(ConfigField.CREDENTIALS): + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) # Validate env (optional, should be one of LogLevel values) - if "env" in config and config.get("env") not in Env: - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: + log_error_log(SkyflowMessages.ErrorLogs.ENV_IS_REQUIRED.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) return True @@ -189,23 +222,23 @@ def validate_update_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, 
"vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) - if "cluster_id" in config and not config.get("cluster_id"): + if ConfigField.CLUSTER_ID in config and not config.get(ConfigField.CLUSTER_ID): raise SkyflowError(SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id), invalid_input_error_code) - if "env" in config and config.get("env") not in Env: + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) return True @@ -214,23 +247,23 @@ def validate_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id" , str, + logger, config, OptionField.CONNECTION_ID , str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise 
SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) return True @@ -239,193 +272,218 @@ def validate_update_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id", str, + logger, config, OptionField.CONNECTION_ID, str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials")) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS)) return True def validate_file_from_request(file_input: FileInput): if file_input is None: + log_error_log(SkyflowMessages.Error.INVALID_FILE_INPUT.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - - has_file = hasattr(file_input, 'file') and 
file_input.file is not None - has_file_path = hasattr(file_input, 'file_path') and file_input.file_path is not None - + + has_file = hasattr(file_input, FileUploadField.FILE) and file_input.file is not None + has_file_path = hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None + # Must provide exactly one of file or file_path if (has_file and has_file_path) or (not has_file and not has_file_path): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value, invalid_input_error_code) - + if has_file: file = file_input.file # Validate file object has required attributes - if not hasattr(file, 'name') or not isinstance(file.name, str) or not file.name.strip(): + if not hasattr(file, FileUploadField.NAME) or not isinstance(file.name, str) or not file.name.strip(): + log_error_log(SkyflowMessages.Error.INVALID_FILE_TYPE.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) - + # Validate file name file_name, _ = os.path.splitext(os.path.basename(file.name)) if not file_name or not file_name.strip(): + log_error_log(SkyflowMessages.Error.INVALID_FILE_NAME.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_NAME.value, invalid_input_error_code) - + elif has_file_path: file_path = file_input.file_path if not isinstance(file_path, str) or not file_path.strip(): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) - + if not os.path.exists(file_path) or not os.path.isfile(file_path): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): - if not hasattr(request, 'file') or 
request.file is None: + if not hasattr(request, FileUploadField.FILE) or request.file is None: + log_error_log(SkyflowMessages.Error.INVALID_FILE_INPUT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - + # Validate file input first validate_file_from_request(request.file) # Optional: entities - if hasattr(request, 'entities') and request.entities is not None: + if hasattr(request, DeidentifyFileRequestField.ENTITIES) and request.entities is not None: if not isinstance(request.entities, list): + log_error_log(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) if not all(isinstance(entity, DetectEntities) for entity in request.entities): + log_error_log(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) # Optional: allow_regex_list - if hasattr(request, 'allow_regex_list') and request.allow_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.ALLOW_REGEX_LIST) and request.allow_regex_list is not None: if not isinstance(request.allow_regex_list, list) or not all(isinstance(x, str) for x in request.allow_regex_list): + log_error_log(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Optional: restrict_regex_list - if hasattr(request, 'restrict_regex_list') and request.restrict_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.RESTRICT_REGEX_LIST) and request.restrict_regex_list is not None: if not isinstance(request.restrict_regex_list, list) or not all(isinstance(x, str) for x in request.restrict_regex_list): + log_error_log(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, logger) raise 
SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) # Optional: token_format if request.token_format is not None and not isinstance(request.token_format, TokenFormat): + log_error_log(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, invalid_input_error_code) # Optional: transformations if request.transformations is not None and not isinstance(request.transformations, Transformations): + log_error_log(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) # Optional: output_processed_image - if hasattr(request, 'output_processed_image') and request.output_processed_image is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE) and request.output_processed_image is not None: if not isinstance(request.output_processed_image, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, invalid_input_error_code) # Optional: output_ocr_text - if hasattr(request, 'output_ocr_text') and request.output_ocr_text is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT) and request.output_ocr_text is not None: if not isinstance(request.output_ocr_text, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, invalid_input_error_code) # Optional: masking_method - # Optional: masking_method - if hasattr(request, 'masking_method') and request.masking_method is not None: + if hasattr(request, DeidentifyFileRequestField.MASKING_METHOD) and request.masking_method is not None: if not isinstance(request.masking_method, MaskingMethod): + log_error_log(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, logger) 
raise SkyflowError(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, invalid_input_error_code) # Optional: pixel_density - if hasattr(request, 'pixel_density') and request.pixel_density is not None: + if hasattr(request, DeidentifyFileRequestField.PIXEL_DENSITY) and request.pixel_density is not None: if not isinstance(request.pixel_density, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, invalid_input_error_code) # Optional: max_resolution - if hasattr(request, 'max_resolution') and request.max_resolution is not None: + if hasattr(request, DeidentifyFileRequestField.MAX_RESOLUTION) and request.max_resolution is not None: if not isinstance(request.max_resolution, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, invalid_input_error_code) # Optional: output_processed_audio - if hasattr(request, 'output_processed_audio') and request.output_processed_audio is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO) and request.output_processed_audio is not None: if not isinstance(request.output_processed_audio, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, invalid_input_error_code) # Optional: output_transcription - if hasattr(request, 'output_transcription') and request.output_transcription is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION) and request.output_transcription is not None: if not isinstance(request.output_transcription, DetectOutputTranscriptions): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, invalid_input_error_code) # Optional: 
bleep - if hasattr(request, 'bleep') and request.bleep is not None: + if hasattr(request, DeidentifyFileRequestField.BLEEP) and request.bleep is not None: if not isinstance(request.bleep, Bleep): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, invalid_input_error_code) - + # Validate gain if request.bleep.gain is not None and not isinstance(request.bleep.gain, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_GAIN.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_GAIN.value, invalid_input_error_code) - + # Validate frequency if request.bleep.frequency is not None and not isinstance(request.bleep.frequency, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value, invalid_input_error_code) - + # Validate start_padding if request.bleep.start_padding is not None and not isinstance(request.bleep.start_padding, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value, invalid_input_error_code) - + # Validate stop_padding if request.bleep.stop_padding is not None and not isinstance(request.bleep.stop_padding, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, invalid_input_error_code) # Optional: output_directory - if hasattr(request, 'output_directory') and request.output_directory is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_DIRECTORY) and request.output_directory is not None: if not isinstance(request.output_directory, str): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, 
invalid_input_error_code) if not os.path.isdir(request.output_directory): + log_error_log(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), logger) raise SkyflowError(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), invalid_input_error_code) # Optional: wait_time - if hasattr(request, 'wait_time') and request.wait_time is not None: + if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_WAIT_TIME.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 and request.wait_time > 64: + if request.wait_time < 0 or request.wait_time > Detect.WAIT_TIME: + log_error_log(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, logger) raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not isinstance(request.values, list) or not all(isinstance(v, dict) for v in request.values): - log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format("INSERT"), logger = logger) + 
log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) - if not len(request.values): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format("INSERT"), logger=logger) + if not request.values: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code) for i, item in enumerate(request.values, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger) - - if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) if request.homogeneous is not None and not isinstance(request.homogeneous, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_HOMOGENEOUS_TYPE.value, invalid_input_error_code) if request.upsert and request.homogeneous: - log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), logger = logger) - raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), invalid_input_error_code) + 
log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), logger = logger) + raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), invalid_input_error_code) if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): @@ -441,15 +499,15 @@ def validate_insert_request(logger, request): for i, item in enumerate(request.tokens, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format("INSERT"), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format(RequestOperation.INSERT), logger=logger) if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format("INSERT", key), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format(RequestOperation.INSERT, key), logger=logger) if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -459,43 +517,43 @@ def validate_insert_request(logger, request): raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE_STRICT: - if len(request.values) != len(request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("INSERT"), logger = logger) + if not request.tokens or len(request.values) != len(request.tokens): + 
log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) for v, t in zip(request.values, request.tokens): if set(v.keys()) != set(t.keys()): - log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format("INSERT"), logger=logger) - raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format(RequestOperation.INSERT), logger=logger) + raise SkyflowError(SkyflowMessages.Error.MISMATCH_OF_FIELDS_AND_TOKENS.value, invalid_input_error_code) def validate_delete_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not request.ids: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code) def validate_query_request(logger, request): - if not request.query: - 
log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format("QUERY"), logger = logger) - raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not isinstance(request.query, str): query_type = str(type(request.query)) raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code) + if not request.query: + log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger=logger) + raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) + if not request.query.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format("QUERY"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not request.query.upper().startswith("SELECT"): + if not request.query.upper().startswith(SqlCommand.SELECT): command = request.query - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) def validate_get_request(logger, request): redaction_type = request.redaction_type @@ -508,23 +566,23 @@ def validate_get_request(logger, request): download_url = request.download_url if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("GET"), logger=logger) + 
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not skyflow_ids and not column_name and not column_values: - log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) if skyflow_ids and (not isinstance(skyflow_ids, list) or not skyflow_ids): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_IDS_TYPE.value.format(type(skyflow_ids)), invalid_input_error_code) if skyflow_ids: for index, skyflow_id in enumerate(skyflow_ids): if skyflow_id is None or skyflow_id == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format("GET", index), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format(RequestOperation.GET, index), logger=logger) if not isinstance(request.return_tokens, bool): @@ -534,7 +592,7 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction_type)), invalid_input_error_code) if fields is not None and (not isinstance(fields, list) or not fields): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(fields)), invalid_input_error_code) if offset is not None and limit is not None: @@ -543,13 +601,13 @@ def validate_get_request(logger, request): invalid_input_error_code) if offset is not None 
and not isinstance(offset, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value(type(offset)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value.format(type(offset)), invalid_input_error_code) if limit is not None and not isinstance(limit, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value(type(limit)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value.format(type(limit)), invalid_input_error_code) if download_url is not None and not isinstance(download_url, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value(type(download_url)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value.format(type(download_url)), invalid_input_error_code) if column_name is not None and (not isinstance(column_name, str) or not column_name.strip()): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) @@ -560,61 +618,58 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if request.return_tokens and redaction_type: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value, invalid_input_error_code) if (column_name or column_values) and request.return_tokens: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format("GET"), + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format(RequestOperation.GET), logger=logger) raise 
SkyflowError(SkyflowMessages.Error.TOKENS_GET_COLUMN_NOT_SUPPORTED.value, invalid_input_error_code) if column_values and not column_name: - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if column_name and not column_values: log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format("GET"), logger = logger) - SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUES.value, invalid_input_error_code) if (column_name or column_values) and skyflow_ids: - log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code) def validate_update_request(logger, request): - skyflow_id = "" - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + if not isinstance(request.data, dict): + raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value.format(type(request.data)), invalid_input_error_code) - try: - skyflow_id = request.data.get("skyflow_id") - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format("UPDATE"), logger=logger) + if not len(request.data.items()): + raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) - if not skyflow_id.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format("UPDATE"), logger = 
logger) + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} + + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) + if skyflow_id is None: + log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) + elif not skyflow_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger=logger) if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("UPDATE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not isinstance(request.return_tokens, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code) - if not isinstance(request.data, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value(type(request.data)), invalid_input_error_code) - - if not len(request.data.items()): - raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) - if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value, invalid_input_error_code) if request.tokens: if not isinstance(request.tokens, dict) or not request.tokens: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("UPDATE"), logger=logger) + 
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -627,14 +682,14 @@ def validate_update_request(logger, request): if request.token_mode == TokenMode.ENABLE_STRICT: if len(field) != len(request.tokens): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) if set(field.keys()) != set(request.tokens.keys()): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError( SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, @@ -645,23 +700,33 @@ def validate_detokenize_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code) if not isinstance(request.data, list): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - if not len(request.data): - log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format("DETOKENIZE"), logger = logger) - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("DETOKENIZE"), logger = logger) + if not request.data: + 
log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format(RequestOperation.DETOKENIZE), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.DETOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code) for item in request.data: - if 'token' not in item: + if ResponseField.TOKEN not in item: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - token = item.get('token') - redaction = item.get('redaction', None) + token = item.get(ResponseField.TOKEN) + + has_redaction = RequestParameter.REDACTION in item + has_redaction_type = RequestParameter.REDACTION_TYPE in item + + if has_redaction: + log_warn(SkyflowMessages.Warning.DETOKENIZE_REDACTION_KEY_DEPRECATED.value, logger) + + if has_redaction_type: + redaction = item.get(RequestParameter.REDACTION_TYPE) + else: + redaction = item.get(RequestParameter.REDACTION, None) if not isinstance(token, str) or not token: - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE"), + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format(RequestOperation.DETOKENIZE), invalid_input_error_code) if redaction is not None and not isinstance(redaction, RedactionType): @@ -673,23 +738,23 @@ def validate_tokenize_request(logger, request): if not isinstance(parameters, list): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(parameters)), invalid_input_error_code) - if not len(parameters): + if not parameters: raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value, invalid_input_error_code) for i, param in enumerate(parameters): if not isinstance(param, dict): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER.value.format(i, type(param)), invalid_input_error_code) - allowed_keys = {"value", "column_group"} + allowed_keys 
= {RequestParameter.VALUE, RequestParameter.COLUMN_GROUP} if set(param.keys()) != allowed_keys: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER_KEY.value.format(i), invalid_input_error_code) - if not param.get("value"): - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.VALUE): + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_VALUE.value.format(i), invalid_input_error_code) - if not param.get("column_group"): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.COLUMN_GROUP): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_COLUMN_GROUP.value.format(i), invalid_input_error_code) @@ -698,32 +763,30 @@ def validate_file_upload_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) # Table - table = getattr(request, "table", None) + table = getattr(request, FileUploadField.TABLE, None) if table is None: raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) elif table.strip() == "": raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) # Skyflow ID - skyflow_id = getattr(request, "skyflow_id", None) - if skyflow_id is None: - raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) - elif skyflow_id.strip() == "": - raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD"), invalid_input_error_code) + skyflow_id = getattr(request, 
FileUploadField.SKYFLOW_ID, None) + if skyflow_id is not None and skyflow_id.strip() == "": + raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format(RequestOperation.FILE_UPLOAD), invalid_input_error_code) # Column Name - column_name = getattr(request, "column_name", None) + column_name = getattr(request, FileUploadField.COLUMN_NAME, None) if column_name is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) elif column_name.strip() == "": - logger.error("Empty column name in FILE_UPLOAD") + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FILE_COLUMN_NAME.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) # File-related attributes - file_path = getattr(request, "file_path", None) - base64_str = getattr(request, "base64", None) - file_object = getattr(request, "file_object", None) - file_name = getattr(request, "file_name", None) + file_path = getattr(request, FileUploadField.FILE_PATH, None) + base64_str = getattr(request, FileUploadField.BASE64, None) + file_object = getattr(request, FileUploadField.FILE_OBJECT, None) + file_name = getattr(request, FileUploadField.FILE_NAME, None) # Check file_path first if present if not is_none_or_empty(file_path): @@ -775,46 +838,57 @@ def validate_invoke_connection_params(logger, query_params, path_params): except TypeError: raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code) -def validate_deidentify_text_request(self, request: DeidentifyTextRequest): +def validate_deidentify_text_request(logger, request: DeidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): + log_error_log(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, invalid_input_error_code) # Validate 
entities if present if request.entities is not None and not isinstance(request.entities, list): + log_error_log(SkyflowMessages.Error.INVALID_ENTITIES_IN_DEIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENTITIES_IN_DEIDENTIFY.value, invalid_input_error_code) # Validate allowed_regex_list if present if request.allow_regex_list is not None and not isinstance(request.allow_regex_list, list): + log_error_log(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Validate restricted_regex_list if present if request.restrict_regex_list is not None and not isinstance(request.restrict_regex_list, list): + log_error_log(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) # Validate token_format if present if request.token_format is not None and not isinstance(request.token_format, TokenFormat): + log_error_log(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, invalid_input_error_code) # Validate transformations if present if request.transformations is not None and not isinstance(request.transformations, Transformations): + log_error_log(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) -def validate_reidentify_text_request(self, request: ReidentifyTextRequest): +def validate_reidentify_text_request(logger, request: ReidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): + log_error_log(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, invalid_input_error_code) # Validate redacted_entities if present if 
request.redacted_entities is not None and not isinstance(request.redacted_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) # Validate masked_entities if present if request.masked_entities is not None and not isinstance(request.masked_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) # Validate plain_text_entities if present if request.plain_text_entities is not None and not isinstance(request.plain_text_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) -def validate_get_detect_run_request(self, request: GetDetectRunRequest): - if not request.run_id or not isinstance(request.run_id, str) or not request.run_id.strip(): +def validate_get_detect_run_request(logger, request: GetDetectRunRequest): + if request.run_id is None or not isinstance(request.run_id, str) or not request.run_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_RUN_ID.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RUN_ID.value, invalid_input_error_code) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..8023646c 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -1,7 +1,9 @@ +from skyflow.error import SkyflowError from skyflow.generated.rest.client import Skyflow from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages from skyflow.utils.logger import log_info +from 
skyflow.utils.constants import OptionField, CredentialField, ConfigField class VaultClient: @@ -14,6 +16,9 @@ def __init__(self, config): self.__logger = None self.__is_config_updated = False self.__bearer_token = None + self.__credentials = None + self.__vault_url = None + self.__is_static_token = None def set_common_skyflow_credentials(self, credentials): self.__common_skyflow_credentials = credentials @@ -23,16 +28,27 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger) - token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env"), - self.__config.get("vault_id"), - logger = self.__logger) - self.initialize_api_client(vault_url, token) - - def initialize_api_client(self, vault_url, token): - self.__api_client = Skyflow(base_url=vault_url, token=token) + if self.__api_client is not None and not self.__is_config_updated: + if self.__is_static_token: + return + if self.__bearer_token is not None and not is_expired(self.__bearer_token): + return + + needs_reinit = self.__api_client is None or self.__is_config_updated + if needs_reinit: + self.__credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger=self.__logger) + self.__vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), + logger=self.__logger) + self.__is_static_token = CredentialField.TOKEN in self.__credentials or CredentialField.API_KEY in self.__credentials + bearer_token = self.get_bearer_token(self.__credentials) + if needs_reinit: + self.initialize_api_client(self.__vault_url, bearer_token) + + def initialize_api_client(self, vault_url, bearer_token): + token_provider = lambda: self.__bearer_token if 
self.__bearer_token is not None else bearer_token # noqa: E731 + self.__api_client = Skyflow(base_url=vault_url, token=token_provider) def get_records_api(self): return self.__api_client.records @@ -50,29 +66,30 @@ def get_detect_file_api(self): return self.__api_client.files def get_vault_id(self): - return self.__config.get("vault_id") + return self.__config.get(ConfigField.VAULT_ID) def get_bearer_token(self, credentials): - if 'api_key' in credentials: - return credentials.get('api_key') - elif 'token' in credentials: - return credentials.get("token") + if CredentialField.API_KEY in credentials: + return credentials.get(CredentialField.API_KEY) + elif CredentialField.TOKEN in credentials: + return credentials.get(CredentialField.TOKEN) options = { - "role_ids": self.__config.get("roles"), - "ctx": self.__config.get("ctx") + OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), + OptionField.CTX: self.__config.get(OptionField.CTX) } + if CredentialField.TOKEN_URI_OPTION in credentials and credentials.get(CredentialField.TOKEN_URI_OPTION): + options[CredentialField.TOKEN_URI_OPTION] = credentials.get(CredentialField.TOKEN_URI_OPTION) - if self.__bearer_token is None or self.__is_config_updated: - if 'path' in credentials: - path = credentials.get("path") + if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token): + if CredentialField.PATH in credentials: self.__bearer_token, _ = generate_bearer_token( - path, + credentials.get(CredentialField.PATH), options, self.__logger ) else: - credentials_string = credentials.get('credentials_string') + credentials_string = credentials.get(CredentialField.CREDENTIALS_STRING) log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, self.__logger) self.__bearer_token, _ = generate_bearer_token_from_creds( credentials_string, @@ -83,10 +100,6 @@ def get_bearer_token(self, credentials): else: log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, 
self.__logger) - if is_expired(self.__bearer_token): - self.__is_config_updated = True - raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - return self.__bearer_token def update_config(self, config): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 81c6ea10..2ce0c104 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,6 +5,8 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader, OptionField, ConfigField +from skyflow.utils import get_credentials class Connection: @@ -12,20 +14,22 @@ def __init__(self, vault_client): self.__vault_client = vault_client def invoke(self, request: InvokeConnectionRequest): - session = requests.Session() - - config = self.__vault_client.get_config() - bearer_token = self.__vault_client.get_bearer_token(config.get("credentials")) - - connection_url = config.get("connection_url") log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) + config = self.__vault_client.get_config() + connection_url = config.get(OptionField.CONNECTION_URL) invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) + + credentials = get_credentials(config.get(ConfigField.CREDENTIALS), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) + + bearer_token = self.__vault_client.get_bearer_token(credentials) + + session = requests.Session() - if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: - 
invoke_connection_request.headers['x-skyflow-authorization'] = bearer_token + if not HttpHeader.X_SKYFLOW_AUTHORIZATION_HEADER.lower() in invoke_connection_request.headers: + invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token - invoke_connection_request.headers['sky-metadata'] = json.dumps(get_metrics()) + invoke_connection_request.headers[SKY_META_DATA_HEADER] = json.dumps(get_metrics()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 44ef2540..f12b6215 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,7 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField, Detect as DetectConstants) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -29,44 +30,44 @@ def __get_headers(self): } return headers - def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: + def __build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: deidentify_text_body = {} parsed_entity_types = request.entities - deidentify_text_body['text'] = request.text - deidentify_text_body['entity_types'] = 
parsed_entity_types - deidentify_text_body['token_type'] = self.__get_token_format(request) - deidentify_text_body['allow_regex'] = request.allow_regex_list - deidentify_text_body['restrict_regex'] = request.restrict_regex_list - deidentify_text_body['transformations'] = self.__get_transformations(request) + deidentify_text_body[DeidentifyField.TEXT] = request.text + deidentify_text_body[DeidentifyField.ENTITY_TYPES] = parsed_entity_types + deidentify_text_body[DeidentifyField.TOKEN_TYPE] = self.__get_token_format(request) + deidentify_text_body[DeidentifyField.ALLOW_REGEX] = request.allow_regex_list + deidentify_text_body[DeidentifyField.RESTRICT_REGEX] = request.restrict_regex_list + deidentify_text_body[DeidentifyField.TRANSFORMATIONS] = self.__get_transformations(request) return deidentify_text_body - def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: + def __build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: parsed_format = Format( redacted=request.redacted_entities, masked=request.masked_entities, plaintext=request.plain_text_entities ) reidentify_text_body = {} - reidentify_text_body['text'] = request.text - reidentify_text_body['format'] = parsed_format + reidentify_text_body[DeidentifyField.TEXT] = request.text + reidentify_text_body[DeidentifyField.FORMAT] = parsed_format return reidentify_text_body def _get_file_extension(self, filename: str): return filename.split('.')[-1].lower() if '.' 
in filename else '' - def __poll_for_processed_file(self, run_id, max_wait_time=64): - max_wait_time = 64 if max_wait_time is None else max_wait_time + def __poll_for_processed_file(self, run_id, max_wait_time=None): + max_wait_time = DetectConstants.WAIT_TIME if max_wait_time is None else max_wait_time files_api = self.__vault_client.get_detect_file_api().with_raw_response current_wait_time = 1 # Start with 1 second try: while True: - response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data + response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data status = response.status - if status == 'IN_PROGRESS': + if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: - return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS') + return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: next_wait_time = current_wait_time * 2 if next_wait_time >= max_wait_time: @@ -76,42 +77,54 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): wait_time = next_wait_time current_wait_time = next_wait_time time.sleep(wait_time) - elif status == 'SUCCESS' or status == 'FAILED': + elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: - raise e + handle_exception(e, self.__vault_client.get_logger()) def __save_deidentify_file_response_output(self, response: DetectRunsResponse, output_directory: str, original_file_name: str, name_without_ext: str): - if not response or not hasattr(response, 'output') or not response.output or not output_directory: + if not response or not hasattr(response, DeidentifyField.OUTPUT) or not response.output or not output_directory: return if not os.path.exists(output_directory): return - deidentify_file_prefix = "processed-" + deidentify_file_prefix = FileProcessing.PROCESSED_PREFIX 
output_list = response.output base_original_filename = os.path.basename(original_file_name) base_name_without_ext = os.path.splitext(base_original_filename)[0] + real_output_dir = os.path.realpath(output_directory) for idx, output in enumerate(output_list): try: - processed_file = get_attribute(output, 'processedFile', 'processed_file') - processed_file_type = get_attribute(output, 'processedFileType', 'processed_file_type') - processed_file_extension = get_attribute(output, 'processedFileExtension', 'processed_file_extension') + processed_file = get_attribute(output, DeidentifyField.PROCESSED_FILE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE) + processed_file_type = get_attribute(output, DeidentifyField.PROCESSED_FILE_TYPE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_TYPE) + processed_file_extension = get_attribute(output, DeidentifyField.PROCESSED_FILE_EXTENSION_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_EXTENSION) if not processed_file: continue decoded_data = base64.b64decode(processed_file) - - if idx == 0 or processed_file_type == 'redacted_file': + + # Sanitize extension from API response to prevent path traversal (CWE-22). + # Avoid os.path.basename here to keep basename mock-free in tests. 
+ safe_ext = None + if processed_file_extension: + raw_ext = str(processed_file_extension).replace('\\', '/').split('/')[-1].lstrip('.') + safe_ext = ''.join(c for c in raw_ext if c.isalnum() or c in ('-', '_')) or 'bin' + + if idx == 0 or processed_file_type == DeidentifyField.REDACTED_FILE: output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) - if processed_file_extension: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") + if safe_ext: + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext}") else: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") - + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext or 'bin'}") + + if not os.path.realpath(output_file_name).startswith(real_output_dir + os.sep): + log_error_log(SkyflowMessages.ErrorLogs.SAVING_DEIDENTIFY_FILE_FAILED.value, self.__vault_client.get_logger()) + continue + with open(output_file_name, 'wb') as f: f.write(decoded_data) except Exception as e: @@ -119,62 +132,62 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o handle_exception(e, self.__vault_client.get_logger()) def __parse_deidentify_file_response(self, data, run_id=None, status=None): - output = getattr(data, "output", []) - status_val = getattr(data, "status", None) or status - run_id_val = getattr(data, "run_id", None) or run_id + output = getattr(data, DeidentifyField.OUTPUT, []) + status_val = getattr(data, DeidentifyField.STATUS, None) or status + run_id_val = getattr(data, DeidentifyField.RUN_ID, None) or run_id word_count = None char_count = None - word_character_count = getattr(data, "word_character_count", None) + word_character_count = getattr(data, DeidentifyField.WORD_CHARACTER_COUNT, 
None) if word_character_count and isinstance(word_character_count, WordCharacterCount): - word_count = word_character_count.word_count - char_count = word_character_count.character_count + word_count = getattr(word_character_count, DeidentifyField.WORD_COUNT, None) + char_count = getattr(word_character_count, DeidentifyField.CHARACTER_COUNT, None) - size = getattr(data, "size", None) + size = getattr(data, DeidentifyField.SIZE, None) size = float(size) if size is not None else None - duration = getattr(data, "duration", None) - pages = getattr(data, "pages", None) - slides = getattr(data, "slides", None) + duration = getattr(data, DeidentifyField.DURATION, None) + pages = getattr(data, DeidentifyField.PAGES, None) + slides = getattr(data, DeidentifyField.SLIDES, None) def output_to_dict_list(output): result = [] for o in output: if isinstance(o, dict): result.append({ - "file": o.get("processed_file"), - "type": o.get("processed_file_type"), - "extension": o.get("processed_file_extension") + DeidentifyField.FILE: o.get(DeidentifyField.PROCESSED_FILE), + DeidentifyField.TYPE: o.get(DeidentifyField.PROCESSED_FILE_TYPE), + DeidentifyField.EXTENSION: o.get(DeidentifyField.PROCESSED_FILE_EXTENSION) }) else: result.append({ - "file": getattr(o, "processed_file", None), - "type": getattr(o, "processed_file_type", None), - "extension": getattr(o, "processed_file_extension", None) + DeidentifyField.FILE: getattr(o, DeidentifyField.PROCESSED_FILE, None), + DeidentifyField.TYPE: getattr(o, DeidentifyField.PROCESSED_FILE_TYPE, None), + DeidentifyField.EXTENSION: getattr(o, DeidentifyField.PROCESSED_FILE_EXTENSION, None) }) return result output_list = output_to_dict_list(output) first_output = output_list[0] if output_list else {} - entities = [o for o in output_list if o.get("type") == "entities"] + entities = [o for o in output_list if o.get(DeidentifyField.TYPE) == FileProcessing.ENTITIES] - base64_string = first_output.get("file", None) - extension = 
first_output.get("extension", None) + base64_string = first_output.get(DeidentifyField.FILE, None) + extension = first_output.get(DeidentifyField.EXTENSION, None) if base64_string is not None: - file_bytes = base64.b64decode(base64_string) - file_obj = io.BytesIO(file_bytes) - file_obj.name = f"deidentified.{extension}" if extension else "processed_file" + file_bytes = base64.b64decode(base64_string) + file_obj = io.BytesIO(file_bytes) + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get("type", "UNKNOWN"), + type=first_output.get(DeidentifyField.TYPE, None), extension=extension, word_count=word_count, char_count=char_count, @@ -188,25 +201,26 @@ def output_to_dict_list(output): ) def __get_token_format(self, request): - if not hasattr(request, "token_format") or request.token_format is None: + if not hasattr(request, DeidentifyField.TOKEN_FORMAT) or request.token_format is None: return None return { - 'default': getattr(request.token_format, "default", None), - 'entity_unq_counter': getattr(request.token_format, "entity_unique_counter", None), - 'entity_only': getattr(request.token_format, "entity_only", None), + DeidentifyField.DEFAULT: getattr(request.token_format, DeidentifyField.DEFAULT, None), + DeidentifyField.ENTITY_UNQ_COUNTER: getattr(request.token_format, DeidentifyField.ENTITY_UNIQUE_COUNTER, None), + DeidentifyField.ENTITY_ONLY: getattr(request.token_format, DeidentifyField.ENTITY_ONLY, None), + DeidentifyField.VAULT_TOKEN: getattr(request.token_format, DeidentifyField.VAULT_TOKEN, None) } def __get_transformations(self, request): - if not hasattr(request, "transformations") or request.transformations is None: + if not hasattr(request, DeidentifyField.TRANSFORMATIONS) or request.transformations is None: return None - shift_dates = getattr(request.transformations, 
"shift_dates", None) + shift_dates = getattr(request.transformations, DeidentifyField.SHIFT_DATES, None) if shift_dates is None: return None return { - 'shift_dates': { - 'max_days': getattr(shift_dates, "max", None), - 'min_days': getattr(shift_dates, "min", None), - 'entity_types': getattr(shift_dates, "entities", None) + DeidentifyField.SHIFT_DATES: { + DeidentifyField.MAX_DAYS: getattr(shift_dates, DeidentifyField.MAX, None), + DeidentifyField.MIN_DAYS: getattr(shift_dates, DeidentifyField.MIN, None), + DeidentifyField.ENTITY_TYPES: getattr(shift_dates, DeidentifyField.ENTITIES, None) } } @@ -216,19 +230,19 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - deidentify_text_body = self.___build_deidentify_text_body(request) + deidentify_text_body = self.__build_deidentify_text_body(request) try: log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.deidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=deidentify_text_body['text'], - entity_types=deidentify_text_body['entity_types'], - allow_regex=deidentify_text_body['allow_regex'], - restrict_regex=deidentify_text_body['restrict_regex'], - token_type=deidentify_text_body['token_type'], - transformations=deidentify_text_body['transformations'], - request_options=self.__get_headers() + text=deidentify_text_body[DeidentifyField.TEXT], + entity_types=deidentify_text_body[DeidentifyField.ENTITY_TYPES], + allow_regex=deidentify_text_body[DeidentifyField.ALLOW_REGEX], + restrict_regex=deidentify_text_body[DeidentifyField.RESTRICT_REGEX], + token_type=deidentify_text_body[DeidentifyField.TOKEN_TYPE], + transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], + request_options={'additional_headers': 
self.__get_headers()} ) deidentify_text_response = parse_deidentify_text_response(api_response) log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -244,15 +258,15 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - reidentify_text_body = self.___build_reidentify_text_body(request) + reidentify_text_body = self.__build_reidentify_text_body(request) try: log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.reidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=reidentify_text_body['text'], - format=reidentify_text_body['format'], - request_options=self.__get_headers() + text=reidentify_text_body[DeidentifyField.TEXT], + format=reidentify_text_body[DeidentifyField.FORMAT], + request_options={'additional_headers': self.__get_headers()} ) reidentify_text_response = parse_reidentify_text_response(api_response) log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -264,14 +278,16 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo def __get_file_from_request(self, request: DeidentifyFileRequest): file_input = request.file - - # Check for file - if hasattr(file_input, 'file') and file_input.file is not None: + + if hasattr(file_input, FileUploadField.FILE) and file_input.file is not None: return file_input.file - - # Check for file_path if file is not provided - if hasattr(file_input, 'file_path') and file_input.file_path is not None: - return open(file_input.file_path, 'rb') + + if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: + with open(file_input.file_path, 'rb') as f: + content = f.read() + bio = io.BytesIO(content) 
+ bio.name = file_input.file_path + return bio def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) @@ -279,151 +295,152 @@ def deidentify_file(self, request: DeidentifyFileRequest): self.__initialize() files_api = self.__vault_client.get_detect_file_api().with_raw_response file_obj = self.__get_file_from_request(request) - file_name = getattr(file_obj, 'name', None) + file_name = getattr(file_obj, FileUploadField.NAME, None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode('utf-8') + base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) try: - if file_extension == 'txt': - req_file = FileDataDeidentifyText(base_64=base64_string, data_format="txt") + if file_extension == FileExtension.TXT: + req_file = FileDataDeidentifyText(base_64=base64_string, data_format=FileExtension.TXT) api_call = files_api.deidentify_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['mp3', 'wav']: + elif file_extension in [FileExtension.MP3, 
FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio + bleep = request.bleep api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'output_transcription': getattr(request, 'output_transcription', None), - 'output_processed_audio': getattr(request, 'output_processed_audio', None), - 'bleep_gain': getattr(request, 'bleep', None).gain if getattr(request, 'bleep', None) is not None else None, - 'bleep_frequency': getattr(request, 'bleep', None).frequency if getattr(request, 'bleep', None) is not None else None, - 'bleep_start_padding': getattr(request, 'bleep', None).start_padding if getattr(request, 'bleep', None) is not None else None, - 'bleep_stop_padding': getattr(request, 'bleep', None).stop_padding if getattr(request, 'bleep', None) is not None else None, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION: getattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO, None), + DeidentifyField.BLEEP_GAIN: bleep.gain if bleep is not None else None, + DeidentifyField.BLEEP_FREQUENCY: bleep.frequency if bleep is not None else None, + 
DeidentifyField.BLEEP_START_PADDING: bleep.start_padding if bleep is not None else None, + DeidentifyField.BLEEP_STOP_PADDING: bleep.stop_padding if bleep is not None else None, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension == 'pdf': + elif file_extension == FileExtension.PDF: req_file = FileDataDeidentifyPdf(base_64=base64_string) api_call = files_api.deidentify_pdf api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'max_resolution': getattr(request, 'max_resolution', None), - 'density': getattr(request, 'pixel_density', None), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MAX_RESOLUTION: getattr(request, DeidentifyFileRequestField.MAX_RESOLUTION, None), + DeidentifyFileRequestField.DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']: + elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: req_file = FileDataDeidentifyImage(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_image api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 
'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'masking_method': getattr(request, 'masking_method', None), - 'output_ocr_text': getattr(request, 'output_ocr_text', None), - 'output_processed_image': getattr(request, 'output_processed_image', None), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MASKING_METHOD: getattr(request, DeidentifyFileRequestField.MASKING_METHOD, None), + DeidentifyFileRequestField.OUTPUT_OCR_TEXT: getattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE, None), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['ppt', 'pptx']: + elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: req_file = FileDataDeidentifyPresentation(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_presentation api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + 
DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['csv', 'xls', 'xlsx']: + elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: req_file = FileDataDeidentifySpreadsheet(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_spreadsheet api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['doc', 'docx']: + elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: req_file = FileDataDeidentifyDocument(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_document api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': 
self.__get_headers()} } - elif file_extension in ['json', 'xml']: + elif file_extension in [FileExtension.JSON, FileExtension.XML]: req_file = FileDataDeidentifyStructuredText(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_structured_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } else: req_file = FileData(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_file api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + 
DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) api_response = api_call(**api_kwargs) - run_id = getattr(api_response.data, 'run_id', None) + run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == 'SUCCESS': + if request.output_directory and processed_response.status == DetectStatus.SUCCESS and file_name: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -448,10 +465,10 @@ def get_detect_run(self, request: GetDetectRunRequest): response = files_api.get_run( run_id, vault_id=self.__vault_client.get_vault_id(), - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) - if response.data.status == 'IN_PROGRESS': - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) + if response.data.status == DetectStatus.IN_PROGRESS: + parsed_response = DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 7cc9ec77..fd085e35 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -8,7 +8,7 @@ from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response, 
encode_column_values, get_metrics -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter, FileUploadField from skyflow.utils.enums import RequestMethod from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log @@ -82,17 +82,14 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return (request.file_name, decoded_bytes) elif request.file_object is not None: - if hasattr(request.file_object, "name") and request.file_object.name: + if hasattr(request.file_object, FileUploadField.NAME) and request.file_object.name: file_name = os.path.basename(request.file_object.name) return (file_name, request.file_object) return None def __get_headers(self): - headers = { - SKY_META_DATA_HEADER: json.dumps(get_metrics()) - } - return headers + return {SKY_META_DATA_HEADER: json.dumps(get_metrics())} def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.VALIDATE_INSERT_REQUEST.value, self.__vault_client.get_logger()) @@ -106,11 +103,11 @@ def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.INSERT_TRIGGERED.value, self.__vault_client.get_logger()) if request.continue_on_error: api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(), - records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options=self.__get_headers()) + records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) else: api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), - request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) + 
request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) insert_response = parse_insert_response(api_response, request.continue_on_error) log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) @@ -125,7 +122,7 @@ def update(self, request: UpdateRequest): validate_update_request(self.__vault_client.get_logger(), request) log_info(SkyflowMessages.Info.UPDATE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} record = V1FieldRecords(fields=field, tokens = request.tokens) records_api = self.__vault_client.get_records_api() @@ -134,11 +131,11 @@ def update(self, request: UpdateRequest): api_response = records_api.record_service_update_record( self.__vault_client.get_vault_id(), request.table, - id=request.data.get("skyflow_id"), + id=request.data.get(ResponseField.SKYFLOW_ID), record=record, tokenization=request.return_tokens, byot=request.token_mode.value, - request_options = self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger()) update_response = parse_update_record_response(api_response) @@ -159,7 +156,7 @@ def delete(self, request: DeleteRequest): self.__vault_client.get_vault_id(), request.table, skyflow_ids=request.ids, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger()) delete_response = parse_delete_response(api_response) @@ -189,7 +186,7 @@ def get(self, request: GetRequest): download_url=request.download_url, 
column_name=request.column_name, column_values=request.column_values, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger()) get_response = parse_get_response(api_response) @@ -209,7 +206,7 @@ def query(self, request: QueryRequest): api_response = query_api.query_service_execute_query( self.__vault_client.get_vault_id(), query=request.query, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger()) query_response = parse_query_response(api_response) @@ -225,8 +222,8 @@ def detokenize(self, request: DetokenizeRequest): self.__initialize() tokens_list = [ V1DetokenizeRecordRequest( - token=item.get('token'), - redaction=item.get('redaction', RedactionType.DEFAULT) + token=item.get(ResponseField.TOKEN), + redaction=item.get(RequestParameter.REDACTION_TYPE) or item.get(RequestParameter.REDACTION, RedactionType.DEFAULT) ) for item in request.data ] @@ -237,7 +234,7 @@ def detokenize(self, request: DetokenizeRequest): self.__vault_client.get_vault_id(), detokenization_parameters=tokens_list, continue_on_error = request.continue_on_error, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) detokenize_response = parse_detokenize_response(api_response) @@ -253,7 +250,7 @@ def tokenize(self, request: TokenizeRequest): self.__initialize() records_list = [ - V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"]) + V1TokenizeRecordRequest(value=item[RequestParameter.VALUE], column_group=item[RequestParameter.COLUMN_GROUP]) for item in request.values ] tokens_api = self.__vault_client.get_tokens_api() @@ -262,7 +259,7 @@ def tokenize(self, request: 
TokenizeRequest): api_response = tokens_api.record_service_tokenize( self.__vault_client.get_vault_id(), tokenization_parameters=records_list, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) tokenize_response = parse_tokenize_response(api_response) log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) @@ -285,7 +282,7 @@ def upload_file(self, request: FileUploadRequest): file=self.__get_file_for_file_upload(request), skyflow_id=request.skyflow_id, return_file_metadata= False, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/data/_file_upload_request.py b/skyflow/vault/data/_file_upload_request.py index d1bd4a44..c5c08b51 100644 --- a/skyflow/vault/data/_file_upload_request.py +++ b/skyflow/vault/data/_file_upload_request.py @@ -1,14 +1,23 @@ -from typing import BinaryIO +from typing import BinaryIO, Optional + +from skyflow.utils import SkyflowMessages +from skyflow.utils.logger import log_warn + class FileUploadRequest: def __init__(self, table: str, - skyflow_id: str, - column_name: str, - file_path: str= None, - base64: str= None, - file_object: BinaryIO= None, - file_name: str= None): + *args, + column_name: Optional[str] = None, + skyflow_id: Optional[str] = None, + file_path: Optional[str] = None, + base64: Optional[str] = None, + file_object: Optional[BinaryIO] = None, + file_name: Optional[str] = None): + if args: + log_warn(SkyflowMessages.Warning.FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED.value) + skyflow_id = args[0] if args else skyflow_id + column_name = args[1] if len(args) > 1 else column_name self.table = table self.skyflow_id = skyflow_id self.column_name = column_name diff --git 
a/skyflow/vault/data/_get_response.py b/skyflow/vault/data/_get_response.py index cf1b0805..a1640254 100644 --- a/skyflow/vault/data/_get_response.py +++ b/skyflow/vault/data/_get_response.py @@ -1,6 +1,6 @@ class GetResponse: def __init__(self, data=None, errors = None): - self.data = data if data else [] + self.data = data if data is not None else [] self.errors = errors def __repr__(self): diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index b340e21c..e56f2113 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -1,22 +1,24 @@ import io +from typing import Optional from skyflow.vault.detect._file import File class DeidentifyFileResponse: def __init__( self, - file_base64: str = None, - file: io.BytesIO = None, - type: str = None, - extension: str = None, - word_count: int = None, - char_count: int = None, - size_in_kb: float = None, - duration_in_seconds: float = None, - page_count: int = None, - slide_count: int = None, - entities: list = None, # list of dicts with keys 'file' and 'extension' - run_id: str = None, - status: str = None, + file_base64: Optional[str] = None, + file: Optional[io.BytesIO] = None, + type: Optional[str] = None, + extension: Optional[str] = None, + word_count: Optional[int] = None, + char_count: Optional[int] = None, + size_in_kb: Optional[float] = None, + duration_in_seconds: Optional[float] = None, + page_count: Optional[int] = None, + slide_count: Optional[int] = None, + entities: Optional[list] = None, + run_id: Optional[str] = None, + status: Optional[str] = None, + errors: Optional[list] = None, ): self.file_base64 = file_base64 self.file = File(file) if file else None @@ -31,6 +33,7 @@ def __init__( self.entities = entities if entities is not None else [] self.run_id = run_id self.status = status + self.errors = errors def __repr__(self): return ( @@ -40,7 +43,7 @@ def __repr__(self): 
f"char_count={self.char_count!r}, size_in_kb={self.size_in_kb!r}, " f"duration_in_seconds={self.duration_in_seconds!r}, page_count={self.page_count!r}, " f"slide_count={self.slide_count!r}, entities={self.entities!r}, " - f"run_id={self.run_id!r}, status={self.status!r})" + f"run_id={self.run_id!r}, status={self.status!r}, errors={self.errors!r})" ) def __str__(self): diff --git a/skyflow/vault/detect/_deidentify_text_response.py b/skyflow/vault/detect/_deidentify_text_response.py index cdb6632e..227b43bc 100644 --- a/skyflow/vault/detect/_deidentify_text_response.py +++ b/skyflow/vault/detect/_deidentify_text_response.py @@ -1,19 +1,21 @@ -from typing import List +from typing import List, Optional from ._entity_info import EntityInfo class DeidentifyTextResponse: - def __init__(self, + def __init__(self, processed_text: str, - entities: List[EntityInfo], + entities: List[EntityInfo], word_count: int, - char_count: int): + char_count: int, + errors: Optional[list] = None): self.processed_text = processed_text self.entities = entities self.word_count = word_count self.char_count = char_count + self.errors = errors def __repr__(self): - return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count})" + return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count}, errors={self.errors})" def __str__(self): return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/detect/_reidentify_text_response.py b/skyflow/vault/detect/_reidentify_text_response.py index 50c3876d..73ad3f5d 100644 --- a/skyflow/vault/detect/_reidentify_text_response.py +++ b/skyflow/vault/detect/_reidentify_text_response.py @@ -1,9 +1,12 @@ +from typing import Optional + class ReidentifyTextResponse: - def __init__(self, processed_text: str): + def __init__(self, processed_text: str, errors: 
Optional[list] = None): self.processed_text = processed_text + self.errors = errors def __repr__(self) -> str: - return f"ReidentifyTextResponse(processed_text='{self.processed_text}')" + return f"ReidentifyTextResponse(processed_text='{self.processed_text}', errors={self.errors})" def __str__(self) -> str: return self.__repr__() \ No newline at end of file diff --git a/tests/client/test_skyflow.py b/tests/client/test_skyflow.py index 3e3681bb..5b7ea675 100644 --- a/tests/client/test_skyflow.py +++ b/tests/client/test_skyflow.py @@ -1,42 +1,42 @@ import unittest -from unittest.mock import patch +from unittest.mock import patch, Mock from skyflow import LogLevel, Env from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow import Skyflow +from skyflow.vault.client.client import VaultClient +from skyflow.vault.data import FileUploadRequest VALID_VAULT_CONFIG = { "vault_id": "VAULT_ID", "cluster_id": "CLUSTER_ID", "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_VAULT_CONFIG = { "cluster_id": "CLUSTER_ID", # Missing vault_id "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } VALID_CONNECTION_CONFIG = { "connection_id": "CONNECTION_ID", "connection_url": "https://CONNECTION_URL", - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_CONNECTION_CONFIG = { "connection_url": "https://CONNECTION_URL", # Missing connection_id - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } -VALID_CREDENTIALS = { - "path": "/path/to/valid_credentials.json" -} +VALID_CREDENTIALS = {"path": "/path/to/valid_credentials.json"} -class TestSkyflow(unittest.TestCase): +class TestSkyflow(unittest.TestCase): def setUp(self): 
self.builder = Skyflow.builder() @@ -49,8 +49,10 @@ def test_add_already_exists_vault_config(self): builder = self.builder.add_vault_config(VALID_VAULT_CONFIG) with self.assertRaises(SkyflowError) as context: builder.add_vault_config(VALID_VAULT_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id"))) - + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id")), + ) def test_add_vault_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -61,11 +63,11 @@ def test_add_vault_config_invalid(self): def test_remove_vault_config_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - result = self.builder.remove_vault_config(VALID_VAULT_CONFIG['vault_id']) + result = self.builder.remove_vault_config(VALID_VAULT_CONFIG["vault_id"]) - self.assertNotIn(VALID_VAULT_CONFIG['vault_id'], self.builder._Builder__vault_configs) + self.assertNotIn(VALID_VAULT_CONFIG["vault_id"], self.builder._Builder__vault_configs) - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_vault_config_invalid(self, mock_log_error): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -73,8 +75,7 @@ def test_remove_vault_config_invalid(self, mock_log_error): self.builder.remove_vault_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_VAULT_ID.value) - - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_vault_config_valid(self, mock_validate): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -94,7 +95,7 @@ def test_get_vault(self): def test_get_vault_with_vault_id_none(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - 
vault = self.builder.get_vault_config(None) + vault = self.builder.get_vault_config(None) config = vault.get("vault_client").get_config() self.assertEqual(self.builder._Builder__vault_list[0], config) @@ -107,19 +108,23 @@ def test_get_vault_with_empty_vault_list_when_vault_id_is_none_raises_error(self def test_get_vault_with_invalid_vault_id_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_vault_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_vault_with_invalid_vault_id_and_non_empty_list_raises_error(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_vault_id') - - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id")) + self.builder.get_vault_config("invalid_vault_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id"), + ) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_build_calls_validate_vault_config(self, mock_validate_vault_config): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -143,7 +148,9 @@ def test_add_already_exists_connection_config(self): with self.assertRaises(SkyflowError) as context: builder.add_connection_config(VALID_CONNECTION_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id)) + self.assertEqual( + context.exception.message, 
SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id) + ) def test_add_connection_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -158,8 +165,7 @@ def test_remove_connection_config_valid(self): self.assertNotIn(VALID_CONNECTION_CONFIG.get("connection_id"), self.builder._Builder__connection_configs) - - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_connection_config_invalid(self, mock_log_error): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -167,7 +173,7 @@ def test_remove_connection_config_invalid(self, mock_log_error): self.builder.remove_connection_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONNECTION_ID.value) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_connection_config_valid(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -194,16 +200,21 @@ def test_get_connection_config_with_connection_id_none(self): def test_get_connection_with_empty_connection_list_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_connection_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_connection_with_invalid_connection_id_raises_error(self): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_connection_id') + 
self.builder.get_connection_config("invalid_connection_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_connection_id')) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_connection_id"), + ) def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(self): self.builder.build() @@ -212,13 +223,12 @@ def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(s self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS.value) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_build_calls_validate_connection_config(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() mock_validate.assert_called_once_with(self.builder._Builder__logger, VALID_CONNECTION_CONFIG) - def test_build_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG).add_connection_config(VALID_CONNECTION_CONFIG) client = self.builder.build() @@ -236,30 +246,31 @@ def test_invalid_credentials(self): self.assertEqual(VALID_CREDENTIALS, self.builder._Builder__skyflow_credentials) self.assertEqual(builder, self.builder) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_skyflow_client_add_remove_vault_config(self, mock_validate_vault_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['vault_id'] = "VAULT_ID" + new_config["vault_id"] = "VAULT_ID" skyflow_client.add_vault_config(new_config) - assert mock_validate_vault_config.call_count == 2 + self.assertEqual(mock_validate_vault_config.call_count, 2) - self.assertEqual("VAULT_ID", - 
skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id")) + self.assertEqual("VAULT_ID", skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id")) - skyflow_client.remove_vault_config(new_config['vault_id']) + skyflow_client.remove_vault_config(new_config["vault_id"]) with self.assertRaises(SkyflowError) as context: - skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id") + skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format( - new_config['vault_id'])) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(new_config["vault_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['env'] = Env.SANDBOX + new_config["env"] = Env.SANDBOX skyflow_client.update_vault_config(new_config) mock_update_config.assert_called_once() @@ -267,29 +278,33 @@ def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): self.assertEqual(VALID_VAULT_CONFIG.get("vault_id"), vault.get("vault_id")) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_skyflow_client_add_remove_connection_config(self, mock_validate_connection_config): skyflow_client = self.builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_id'] = "CONNECTION_ID" + new_config["connection_id"] = "CONNECTION_ID" skyflow_client.add_connection_config(new_config) - assert mock_validate_connection_config.call_count == 2 - 
self.assertEqual("CONNECTION_ID", skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id")) + self.assertEqual(mock_validate_connection_config.call_count, 2) + self.assertEqual( + "CONNECTION_ID", skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + ) skyflow_client.remove_connection_config("CONNECTION_ID") with self.assertRaises(SkyflowError) as context: - skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id") - - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config['connection_id'])) + skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config["connection_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_connection_config(self, mock_update_config): builder = self.builder skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_url'] = 'updated_url' + new_config["connection_url"] = "updated_url" skyflow_client.update_connection_config(new_config) mock_update_config.assert_called_once() @@ -305,28 +320,165 @@ def test_skyflow_add_and_update_skyflow_credentials(self): self.assertEqual(VALID_CREDENTIALS, builder._Builder__skyflow_credentials) new_credentials = VALID_CREDENTIALS.copy() - new_credentials['path'] = 'path/to/new_credentials' + new_credentials["path"] = "path/to/new_credentials" skyflow_client.update_skyflow_credentials(new_credentials) self.assertEqual(new_credentials, builder._Builder__skyflow_credentials) - def test_skyflow_add_and_update_log_level(self): builder = self.builder - skyflow_client = 
builder.add_connection_config(VALID_CONNECTION_CONFIG).build() + skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() skyflow_client.set_log_level(LogLevel.INFO) self.assertEqual(LogLevel.INFO, builder._Builder__log_level) - skyflow_client.update_log_level(LogLevel.ERROR) - self.assertEqual(LogLevel.ERROR, builder._Builder__log_level) - - - @patch('skyflow.client.Skyflow.Builder.get_vault_config') + @patch("skyflow.client.Skyflow.Builder.get_vault_config") def test_skyflow_vault_and_connection_method(self, mock_get_vault_config): builder = self.builder - skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + skyflow_client = ( + builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + ) skyflow_client.vault() skyflow_client.connection() - mock_get_vault_config.assert_called_once() \ No newline at end of file + mock_get_vault_config.assert_called_once() + + def test_detect_returns_detect_controller(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + from skyflow.vault.controller import Detect + result = skyflow_client.detect() + self.assertIsInstance(result, Detect) + + def test_detect_with_explicit_vault_id(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + from skyflow.vault.controller import Detect + result = skyflow_client.detect(VALID_VAULT_CONFIG["vault_id"]) + self.assertIsInstance(result, Detect) + + def test_detect_with_invalid_vault_id_raises_error(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + with self.assertRaises(SkyflowError) as context: + skyflow_client.detect("invalid_vault_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id"), + ) + + @patch("skyflow.vault.client.client.VaultClient.update_config") + def 
test_update_vault_config_with_invalid_vault_id_raises_error(self, _mock): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + invalid_config = VALID_VAULT_CONFIG.copy() + invalid_config["vault_id"] = "non_existent_vault_id" + with self.assertRaises(SkyflowError) as context: + skyflow_client.update_vault_config(invalid_config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("non_existent_vault_id"), + ) + + @patch("skyflow.vault.client.client.VaultClient.update_config") + def test_update_connection_config_with_invalid_connection_id_raises_error(self, _mock): + skyflow_client = self.builder.add_connection_config(VALID_CONNECTION_CONFIG).build() + invalid_config = VALID_CONNECTION_CONFIG.copy() + invalid_config["connection_id"] = "non_existent_connection_id" + with self.assertRaises(SkyflowError) as context: + skyflow_client.update_connection_config(invalid_config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("non_existent_connection_id"), + ) + + +class TestVaultClient(unittest.TestCase): + def _make_client(self): + client = VaultClient({"vault_id": "test_vault"}) + client._VaultClient__api_client = Mock() + return client + + def test_get_detect_text_api_returns_strings(self): + client = self._make_client() + result = client.get_detect_text_api() + self.assertEqual(result, client._VaultClient__api_client.strings) + + def test_get_detect_file_api_returns_files(self): + client = self._make_client() + result = client.get_detect_file_api() + self.assertEqual(result, client._VaultClient__api_client.files) + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.is_expired", return_value=True) + def test_get_bearer_token_passes_token_uri_option(self, _mock_expired, mock_gen): + mock_gen.return_value = ("test_token", "bearer") + client = 
VaultClient({"vault_id": "test_vault"}) + credentials = { + "credentials_string": '{"clientID":"id","privateKey":"pk","keyID":"kid","tokenURI":"https://token.uri"}', + "token_uri": "https://custom-token-uri.com/token", + } + client.get_bearer_token(credentials) + options_passed = mock_gen.call_args[0][1] + self.assertIn("token_uri", options_passed) + self.assertEqual(options_passed["token_uri"], "https://custom-token-uri.com/token") + + +class TestUpdateLogLevelDeprecation(unittest.TestCase): + def _build_client(self): + return Skyflow.builder().add_vault_config(VALID_VAULT_CONFIG).build() + + def test_update_log_level_emits_deprecation_warning(self): + client = self._build_client() + with patch('skyflow.client.skyflow.log_warn') as mock_warn: + client.update_log_level(LogLevel.INFO) + mock_warn.assert_called_once() + self.assertIn("set_log_level", mock_warn.call_args[0][0]) + + def test_update_log_level_delegates_to_set_log_level(self): + client = self._build_client() + client.update_log_level(LogLevel.INFO) + self.assertEqual(client.get_log_level(), LogLevel.INFO) + + +class TestFileUploadRequestDeprecation(unittest.TestCase): + def test_keyword_args_no_warning(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest( + table="table", + column_name="col", + skyflow_id="sky123", + ) + mock_warn.assert_not_called() + self.assertEqual(req.table, "table") + self.assertEqual(req.column_name, "col") + self.assertEqual(req.skyflow_id, "sky123") + + def test_only_table_positional_no_warning(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest("table", column_name="col", skyflow_id="sky123") + mock_warn.assert_not_called() + self.assertEqual(req.table, "table") + + def test_old_positional_order_emits_deprecation_warning(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest("table", "sky123", "col") + 
mock_warn.assert_called_once() + self.assertIn("FileUploadRequest", mock_warn.call_args[0][0]) + + def test_old_positional_order_remaps_args_correctly(self): + req = FileUploadRequest("table", "sky123", "col") + self.assertEqual(req.skyflow_id, "sky123") + self.assertEqual(req.column_name, "col") + + def test_single_positional_arg_emits_warning_and_sets_skyflow_id(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest("table", "sky123") + mock_warn.assert_called_once() + self.assertEqual(req.skyflow_id, "sky123") + self.assertIsNone(req.column_name) + + def test_optional_fields_default_to_none(self): + req = FileUploadRequest(table="table") + self.assertIsNone(req.skyflow_id) + self.assertIsNone(req.column_name) + self.assertIsNone(req.file_path) + self.assertIsNone(req.base64) + self.assertIsNone(req.file_object) + self.assertIsNone(req.file_name) diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index 7ffb36df..505a7261 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -5,35 +5,57 @@ from unittest.mock import patch import os from skyflow.error import SkyflowError -from skyflow.service_account import is_expired, generate_bearer_token, \ - generate_bearer_token_from_creds +from skyflow.service_account import is_expired, generate_bearer_token, generate_bearer_token_from_creds from skyflow.utils import SkyflowMessages -from skyflow.service_account._utils import get_service_account_token, get_signed_jwt, generate_signed_data_tokens, get_signed_data_token_response_object, generate_signed_data_tokens_from_creds +from skyflow.service_account._utils import ( + get_service_account_token, + get_signed_jwt, + generate_signed_data_tokens, + get_signed_data_token_response_object, + generate_signed_data_tokens_from_creds, + _validate_and_resolve_ctx, + _normalize_credentials, + get_signed_tokens, +) creds_path = 
os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) VALID_CREDENTIALS_STRING = json.dumps(credentials) -CREDENTIALS_WITHOUT_CLIENT_ID = { - 'privateKey': 'private_key' -} +CREDENTIALS_WITHOUT_CLIENT_ID = {"privateKey": "private_key"} -CREDENTIALS_WITHOUT_KEY_ID = { - 'privateKey': 'private_key', - 'clientID': 'client_id' -} +CREDENTIALS_WITHOUT_KEY_ID = {"privateKey": "private_key", "clientID": "client_id"} -CREDENTIALS_WITHOUT_TOKEN_URI = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id' -} +CREDENTIALS_WITHOUT_TOKEN_URI = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id"} VALID_SERVICE_ACCOUNT_CREDS = credentials +# Snake-case version of the real credentials (keys remapped to snake_case) +SNAKE_CASE_CREDS = { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], +} + +SNAKE_CASE_CREDS_STRING = json.dumps( + { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], + } +) + + class TestServiceAccountUtils(unittest.TestCase): + # ── is_expired ──────────────────────────────────────────────────────────── + + def test_is_expired_none_token(self): + self.assertTrue(is_expired(None)) + def test_is_expired_empty_token(self): self.assertTrue(is_expired("")) @@ -44,7 +66,7 @@ def test_is_expired_non_expired_token(self): def test_is_expired_expired_token(self): past_time = time.time() - 1000 - token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") + token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) @patch("skyflow.utils.logger._log_helpers.log_error_log") @@ -53,6 +75,8 @@ def 
test_is_expired_general_exception(self, mock_jwt_decode, mock_log_error): token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) + # ── generate_bearer_token ───────────────────────────────────────────────── + @patch("builtins.open", side_effect=FileNotFoundError) def test_generate_bearer_token_invalid_file_path(self, mock_open): with self.assertRaises(SkyflowError) as context: @@ -72,6 +96,8 @@ def test_generate_bearer_token_valid_file_path(self, mock_generate_bearer_token) generate_bearer_token(creds_path) mock_generate_bearer_token.assert_called_once() + # ── generate_bearer_token_from_creds ────────────────────────────────────── + @patch("skyflow.service_account._utils.get_service_account_token") def test_generate_bearer_token_from_creds_with_valid_json_string(self, mock_generate_bearer_token): generate_bearer_token_from_creds(VALID_CREDENTIALS_STRING) @@ -82,10 +108,11 @@ def test_generate_bearer_token_from_creds_invalid_json(self): generate_bearer_token_from_creds("invalid_json") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + # ── get_service_account_token ───────────────────────────────────────────── + def test_get_service_account_token_missing_private_key(self): - incomplete_credentials = {} with self.assertRaises(SkyflowError) as context: - get_service_account_token(incomplete_credentials, {}, None) + get_service_account_token({}, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) def test_get_service_account_token_missing_client_id_key(self): @@ -107,6 +134,102 @@ def test_get_service_account_token_with_valid_credentials(self): access_token, _ = get_service_account_token(VALID_SERVICE_ACCOUNT_CREDS, {}, None) self.assertTrue(access_token) + def test_get_service_account_token_with_snake_case_creds(self): + access_token, _ = get_service_account_token(SNAKE_CASE_CREDS, {}, None) + 
self.assertTrue(access_token) + + def test_get_service_account_token_missing_private_key_snake(self): + creds = { + "client_id": "id", + "key_id": "kid", + "token_uri": "https://example.com", + } + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) + + def test_get_service_account_token_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_service_account_token_invalid_token_uri_in_options(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": "not-a-valid-url"} + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, options, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"role_ids": ["role1", "role2"]} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + access_token, token_type = get_service_account_token(creds, options, None) + self.assertEqual(access_token, "token") + self.assertEqual(token_type, "bearer") 
+ args, kwargs = mock_auth_api.authentication_service_get_auth_token.call_args + self.assertIn("scope", kwargs) + self.assertEqual(kwargs["scope"], "role:role1 role:role2") + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError + + mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value + ) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.side_effect = Exception("some error") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value) + + # ── get_signed_jwt ──────────────────────────────────────────────────────── @patch("jwt.encode", side_effect=Exception) def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): @@ 
-114,25 +237,178 @@ def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None) self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value) + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_string_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": "valid_ctx"}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], "valid_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_dict_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": {"role": "admin"}}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], {"role": "admin"}) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_empty_string_ctx_not_added(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": ""}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertNotIn("ctx", payload) + + # ── get_signed_data_token_response_object ───────────────────────────────── + def test_get_signed_data_token_response_object(self): token = "sample_token" signed_token = "signed_sample_token" response = get_signed_data_token_response_object(signed_token, token) + self.assertIsInstance(response, tuple) self.assertEqual(response[0], token) self.assertEqual(response[1], signed_token) + # ── get_signed_tokens ───────────────────────────────────────────────────── + + @patch("jwt.encode", side_effect=Exception("jwt error")) + def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): + creds = { + "privateKey": 
"private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"data_tokens": ["token1"]} + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) + + def test_get_signed_tokens_returns_list_one_per_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + def test_get_signed_tokens_items_are_tuples_with_token_and_signed_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + for item in result: + self.assertIsInstance(item, tuple) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[1][0], "token2") + self.assertTrue(result[0][1].startswith("signed_token_")) + self.assertTrue(result[1][1].startswith("signed_token_")) + + def test_get_signed_tokens_returns_list_single_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + def test_get_signed_tokens_empty_data_tokens_returns_empty_list(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": []}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_string_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": "my_ctx"}) + call_args = mock_jwt_encode.call_args + claims = call_args[0][0] if call_args[0] else call_args.kwargs.get("args", [None])[0] + # jwt.encode(claims, key, algorithm=...) 
— first positional arg is claims + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], "my_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_dict_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + ctx_dict = {"role": "admin", "dept": "eng"} + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ctx_dict}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], ctx_dict) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_empty_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ""}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_none_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": None}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + def test_get_signed_tokens_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_missing_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + 
"keyID": "kid", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_with_snake_case_creds(self): + result = get_signed_tokens(SNAKE_CASE_CREDS, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + # ── generate_signed_data_tokens (file path) ─────────────────────────────── + def test_generate_signed_data_tokens_from_file_path(self): - creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") - options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'} + options = {"data_tokens": ["token1", "token2"], "ctx": "ctx"} result = generate_signed_data_tokens(creds_path, options) self.assertEqual(len(result), 2) def test_generate_signed_data_tokens_from_invalid_file_path(self): options = {"data_tokens": ["token1", "token2"]} with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens('credentials1.json', options) + generate_signed_data_tokens("credentials1.json", options) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) + def test_generate_signed_data_tokens_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "department": "finance"}} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 1) + + # ── generate_signed_data_tokens_from_creds (string) ────────────────────── + def test_generate_signed_data_tokens_from_creds(self): options = {"data_tokens": ["token1", "token2"]} result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) @@ -140,7 +416,177 @@ def test_generate_signed_data_tokens_from_creds(self): def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): options = {"data_tokens": 
["token1", "token2"]} - credentials_string = '{' with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens_from_creds(credentials_string, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) \ No newline at end of file + generate_signed_data_tokens_from_creds("{", options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + + def test_generate_signed_data_tokens_from_creds_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "level": 3}} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 1) + + # ── snake_case end-to-end ───────────────────────────────────────────────── + + def test_generate_signed_data_tokens_with_snake_creds_file(self): + """generate_signed_data_tokens reads the file (camelCase) but the normalize fn is a no-op for camelCase.""" + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_snake(self): + result = generate_signed_data_tokens_from_creds(SNAKE_CASE_CREDS_STRING, options={"data_tokens": ["t1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + # ── _normalize_credentials ──────────────────────────────────────────────── + + def test_normalize_credentials_snake_case(self): + snake = { + "private_key": "pk", + "client_id": "cid", + "key_id": "kid", + "token_uri": "https://uri", + "client_name": "name", + } + result = _normalize_credentials(snake) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertEqual(result["clientName"], "name") + self.assertNotIn("private_key", result) + 
self.assertNotIn("client_id", result) + self.assertNotIn("key_id", result) + self.assertNotIn("token_uri", result) + self.assertNotIn("client_name", result) + + def test_normalize_credentials_camel_case_unchanged(self): + camel = { + "privateKey": "pk", + "clientID": "cid", + "keyID": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(camel) + self.assertEqual(result, camel) + + def test_normalize_credentials_mixed_keys(self): + mixed = { + "private_key": "pk", + "clientID": "cid", + "key_id": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(mixed) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertNotIn("private_key", result) + self.assertNotIn("key_id", result) + + def test_normalize_credentials_unknown_key_passes_through(self): + creds = {"unknown_field": "value", "anotherField": "val2"} + result = _normalize_credentials(creds) + self.assertEqual(result["unknown_field"], "value") + self.assertEqual(result["anotherField"], "val2") + + def test_normalize_credentials_empty_dict(self): + self.assertEqual(_normalize_credentials({}), {}) + + # ── _validate_and_resolve_ctx ───────────────────────────────────────────── + + def test_validate_and_resolve_ctx_none(self): + self.assertIsNone(_validate_and_resolve_ctx(None)) + + def test_validate_and_resolve_ctx_empty_string(self): + self.assertIsNone(_validate_and_resolve_ctx("")) + self.assertIsNone(_validate_and_resolve_ctx(" ")) + + def test_validate_and_resolve_ctx_valid_string(self): + self.assertEqual(_validate_and_resolve_ctx("user_12345"), "user_12345") + + def test_validate_and_resolve_ctx_empty_dict(self): + self.assertIsNone(_validate_and_resolve_ctx({})) + + def test_validate_and_resolve_ctx_valid_dict(self): + ctx = {"role": "admin", "department": "finance"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + 
+ def test_validate_and_resolve_ctx_dict_with_alphanumeric_keys(self): + ctx = {"role_1": "admin", "dept2": "finance", "ABC_123": "value"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_hyphen(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"valid_key": "value", "invalid-key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_space(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_dot(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid.key": "value"}) + + def test_validate_and_resolve_ctx_valid_type_int(self): + self.assertEqual(_validate_and_resolve_ctx(42), 42) + + def test_validate_and_resolve_ctx_valid_type_float(self): + self.assertEqual(_validate_and_resolve_ctx(3.14), 3.14) + + def test_validate_and_resolve_ctx_valid_type_bool_true(self): + self.assertEqual(_validate_and_resolve_ctx(True), True) + + def test_validate_and_resolve_ctx_valid_type_bool_false(self): + self.assertEqual(_validate_and_resolve_ctx(False), False) + + def test_validate_and_resolve_ctx_invalid_type_list(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx(["a", "b"]) + + def test_validate_and_resolve_ctx_dict_with_mixed_value_types(self): + ctx = {"role": "admin", "level": 3, "active": True, "timestamp": "2025-12-25T10:30:00Z"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_nested_objects(self): + ctx = {"role": "admin", "metadata": {"level": 2, "tags": ["a", "b"]}} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + # ── additional coverage gaps ────────────────────────────────────────────── + + @patch("skyflow.service_account._utils.jwt.decode", side_effect=jwt.ExpiredSignatureError) + def test_is_expired_expired_signature_error(self, 
mock_decode): + token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") + self.assertTrue(is_expired(token)) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_token_uri_option_override(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + override_uri = "https://override-url.com" + options = {"token_uri": override_uri} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + get_service_account_token(creds, options, None) + mock_get_signed_jwt.assert_called_once() + call_args = mock_get_signed_jwt.call_args + self.assertEqual(call_args[0][3], override_uri) + + @patch("json.load", side_effect=json.JSONDecodeError("bad json", "", 0)) + def test_generate_signed_data_tokens_from_file_invalid_json(self, mock_load): + invalid_path = os.path.join(os.path.dirname(__file__), "invalid_creds.json") + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(invalid_path, {"data_tokens": ["t1"]}) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.FILE_INVALID_JSON.value.format(invalid_path), + ) diff --git a/tests/utils/test__helpers.py b/tests/utils/test__helpers.py index 8b55abf3..6016c798 100644 --- a/tests/utils/test__helpers.py +++ b/tests/utils/test__helpers.py @@ -1,5 +1,5 @@ import unittest -from skyflow.utils import get_base_url, format_scope +from skyflow.utils import get_base_url, format_scope, is_valid_url VALID_URL = "https://example.com/path?query=1" BASE_URL = "https://example.com" @@ -35,4 +35,28 @@ def test_format_scope_single_scope(self): def 
test_format_scope_special_characters(self): scopes_with_special_chars = ["admin", "user:write", "read-only"] expected_result = "role:admin role:user:write role:read-only" - self.assertEqual(format_scope(scopes_with_special_chars), expected_result) \ No newline at end of file + self.assertEqual(format_scope(scopes_with_special_chars), expected_result) + + def test_is_valid_url_valid(self): + self.assertTrue(is_valid_url("https://example.com")) + self.assertTrue(is_valid_url("https://example.com/path")) + + def test_is_valid_url_invalid(self): + self.assertFalse(is_valid_url("http://example.com")) + self.assertFalse(is_valid_url("ftp://example.com")) + self.assertFalse(is_valid_url("example.com")) + self.assertFalse(is_valid_url("invalid-url")) + self.assertFalse(is_valid_url("")) + + def test_is_valid_url_none(self): + self.assertFalse(is_valid_url(None)) + + def test_is_valid_url_no_scheme(self): + self.assertFalse(is_valid_url("www.example.com")) + + def test_is_valid_url_exception(self): + class BadStr: + def __str__(self): + raise Exception("bad str") + + self.assertFalse(is_valid_url(BadStr())) \ No newline at end of file diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6eaacf47..1363ad7d 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1,38 +1,65 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock, PropertyMock import os -import json -from unittest.mock import MagicMock from urllib.parse import quote +import tempfile, json from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.generated.rest import ErrorResponse -from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ - parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ - parse_detokenize_response, parse_tokenize_response, 
parse_query_response, parse_invoke_connection_response, \ - handle_exception, validate_api_key, encode_column_values, parse_deidentify_text_response, \ - parse_reidentify_text_response, convert_detected_entity_to_entity_info -from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error +from skyflow.service_account import ( + generate_bearer_token, + generate_signed_data_tokens, + generate_signed_data_tokens_from_creds, + generate_bearer_token_from_creds, +) +from skyflow.utils import ( + get_credentials, + SkyflowMessages, + get_vault_url, + construct_invoke_connection_request, + parse_insert_response, + parse_update_record_response, + parse_delete_response, + parse_get_response, + parse_detokenize_response, + parse_tokenize_response, + parse_query_response, + parse_invoke_connection_response, + handle_exception, + validate_api_key, + encode_column_values, + parse_deidentify_text_response, + parse_reidentify_text_response, + convert_detected_entity_to_entity_info, +) +from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error, r_urlencode from skyflow.utils.enums import EnvUrls, Env, ContentType from skyflow.vault.connection import InvokeConnectionResponse from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse from skyflow.vault.tokens import DetokenizeResponse, TokenizeResponse creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) TEST_ERROR_MESSAGE = "Test error message." 
VALID_ENV_CREDENTIALS = credentials -class TestUtils(unittest.TestCase): +class TestUtils(unittest.TestCase): @patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": json.dumps(VALID_ENV_CREDENTIALS)}) def test_get_credentials_env_variable(self): credentials = get_credentials() - credentials_string = credentials.get('credentials_string') - self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n')) + credentials_string = credentials.get("credentials_string") + self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace("\n", "\\n")) + + @patch("skyflow.utils._utils.dotenv.find_dotenv", return_value=None) + @patch.dict(os.environ, {}, clear=True) + def test_get_credentials_no_credentials_raises(self, mock_find_dotenv): + with self.assertRaises(SkyflowError) as context: + get_credentials(config_level_creds=None, common_skyflow_creds=None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) def test_get_credentials_with_config_level_creds(self): test_creds = {"authToken": "test_token"} @@ -58,11 +85,13 @@ def test_get_vault_url_with_invalid_cluster_id(self): valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id)) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id) + ) def test_get_vault_url_with_invalid_env(self): valid_cluster_id = "cluster_id" - valid_env =EnvUrls.DEV + valid_env = EnvUrls.DEV valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) @@ -77,7 +106,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): "http_code": 400, "http_status": "Bad Request", "grpc_code": 3, - "details": ["detail1"] + 
"details": ["detail1"], } } @@ -88,13 +117,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): handle_json_error(mock_error, error_dict, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Dict error message", - 400, - request_id, - "Bad Request", - 3, - ["detail1"], - logger=mock_logger + "Dict error message", 400, request_id, "Bad Request", 3, ["detail1"], logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -107,7 +130,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ "http_code": 403, "http_status": "Forbidden", "grpc_code": 7, - "details": ["detail2"] + "details": ["detail2"], } } @@ -118,13 +141,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ handle_json_error(mock_error, mock_error_response, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "ErrorResponse message", - 403, - request_id, - "Forbidden", - 7, - ["detail2"], - logger=mock_logger + "ErrorResponse message", 403, request_id, "Forbidden", 7, ["detail2"], logger=mock_logger ) def test_parse_path_params(self): @@ -138,13 +155,56 @@ def test_to_lowercase_keys(self): expected_output = {"key1": "value1", "key2": "value2"} self.assertEqual(to_lowercase_keys(input_dict), expected_output) + def test_r_urlencode_with_list_input(self): + pairs = {} + r_urlencode([], pairs, ["a", "b"]) + self.assertIn("[0]", pairs) + self.assertIn("[1]", pairs) + self.assertEqual(pairs["[0]"], "a") + self.assertEqual(pairs["[1]"], "b") + + def test_r_urlencode_with_tuple_input(self): + pairs = {} + r_urlencode([], pairs, ("x", "y")) + self.assertIn("[0]", pairs) + self.assertEqual(pairs["[0]"], "x") + def test_get_metrics(self): metrics = get_metrics() - self.assertIn('sdk_name_version', metrics) - self.assertIn('sdk_client_device_model', metrics) - self.assertIn('sdk_client_os_details', metrics) - self.assertIn('sdk_runtime_details', metrics) + 
self.assertIn("sdk_name_version", metrics) + self.assertIn("sdk_client_device_model", metrics) + self.assertIn("sdk_client_os_details", metrics) + self.assertIn("sdk_runtime_details", metrics) + + def test_get_metrics_platform_node_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + with patch("skyflow.utils._utils.platform") as mock_platform: + mock_platform.node.side_effect = OSError("no node") + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_device_model"], "") + utils_module._CACHED_METRICS.clear() + + def test_get_metrics_sys_attribute_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + + class _RaisingSys: + @property + def platform(self): + raise RuntimeError("no platform") + + @property + def version(self): + raise RuntimeError("no version") + with patch("skyflow.utils._utils.sys", _RaisingSys()): + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_os_details"], "") + self.assertIn("sdk_runtime_details", metrics) + utils_module._CACHED_METRICS.clear() def test_construct_invoke_connection_request_valid(self): mock_connection_request = Mock() @@ -164,7 +224,7 @@ def test_construct_invoke_connection_request_valid(self): self.assertEqual(result.url, expected_url) self.assertEqual(result.method, "POST") - self.assertEqual(result.headers['Content-Type'], ContentType.JSON.value) + self.assertEqual(result.headers["Content-Type"], ContentType.JSON.value) self.assertEqual(result.body, json.dumps(mock_connection_request.body)) @@ -230,9 +290,7 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} - mock_connection_request.body = { - "name": (None, "John Doe") - } + mock_connection_request.body = {"name": (None, 
"John Doe")} mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -242,13 +300,27 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): self.assertIsInstance(result, PreparedRequest) + def test_parse_insert_response_with_tokens_continue_on_error(self): + api_response = Mock() + api_response.headers = {"x-request-id": "req-1"} + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1", "tokens": {"col1": "tok1"}}]}}, + ] + ) + result = parse_insert_response(api_response, continue_on_error=True) + self.assertEqual(result.inserted_fields[0]["col1"], "tok1") + self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1") + def test_parse_insert_response(self): api_response = Mock() api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - api_response.data = Mock(responses=[ - {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, - {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}} - ]) + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, + {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}}, + ] + ) result = parse_insert_response(api_response, continue_on_error=True) self.assertEqual(len(result.inserted_fields), 1) self.assertEqual(len(result.errors), 1) @@ -262,17 +334,19 @@ def test_parse_insert_response(self): def test_parse_insert_response_continue_on_error_false(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), - Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}) - ]) + mock_api_response.data = Mock( + records=[ + Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), + Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}), + ] + ) result = 
parse_insert_response(mock_api_response, continue_on_error=False) self.assertIsInstance(result, InsertResponse) expected_inserted_fields = [ {"skyflow_id": "id_1", "token1": "token_value1"}, - {"skyflow_id": "id_2", "token2": "token_value2"} + {"skyflow_id": "id_2", "token2": "token_value2"}, ] self.assertEqual(result.inserted_fields, expected_inserted_fields) @@ -283,8 +357,8 @@ def test_parse_update_record_response(self): api_response.skyflow_id = "id1" api_response.tokens = {"token1": "value1"} result = parse_update_record_response(api_response) - self.assertEqual(result.updated_field['skyflow_id'], "id1") - self.assertEqual(result.updated_field['token1'], "value1") + self.assertEqual(result.updated_field["skyflow_id"], "id1") + self.assertEqual(result.updated_field["token1"], "value1") def test_parse_delete_response_successful(self): mock_api_response = Mock() @@ -302,42 +376,39 @@ def test_parse_delete_response_successful(self): def test_parse_get_response_successful(self): mock_api_response = Mock() mock_api_response.records = [ - Mock(fields={'field1': 'value1', 'field2': 'value2'}), - Mock(fields={'field1': 'value3', 'field2': 'value4'}) + Mock(fields={"field1": "value1", "field2": "value2"}), + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_get_response(mock_api_response) self.assertIsInstance(result, GetResponse) - expected_data = [ - {'field1': 'value1', 'field2': 'value2'}, - {'field1': 'value3', 'field2': 'value4'} - ] + expected_data = [{"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"}] self.assertEqual(result.data, expected_data) - # self.assertEqual(result.errors, None) + self.assertIsNone(result.errors) def test_parse_detokenize_response_with_mixed_records(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(token="token1", value="value1", value_type="Type1", error=None), - 
Mock(token="token2", value=None, value_type=None, error="Some error"), - Mock(token="token3", value="value3", value_type="Type2", error=None), - ]) + mock_api_response.data = Mock( + records=[ + Mock(token="token1", value="value1", value_type="Type1", error=None), + Mock(token="token2", value=None, value_type=None, error="Some error"), + Mock(token="token3", value="value3", value_type="Type2", error=None), + ] + ) result = parse_detokenize_response(mock_api_response) self.assertIsInstance(result, DetokenizeResponse) expected_detokenized_fields = [ {"token": "token1", "value": "value1", "type": "Type1"}, - {"token": "token3", "value": "value3", "type": "Type2"} + {"token": "token3", "value": "value3", "type": "Type2"}, ] - expected_errors = [ - {"token": "token2", "error": "Some error", "request_id": "12345"} - ] + expected_errors = [{"token": "token2", "error": "Some error", "request_id": "12345"}] self.assertEqual(result.detokenized_fields, expected_detokenized_fields) self.assertEqual(result.errors, expected_errors) @@ -353,11 +424,7 @@ def test_parse_tokenize_response_with_valid_records(self): result = parse_tokenize_response(mock_api_response) self.assertIsInstance(result, TokenizeResponse) - expected_tokenized_fields = [ - {"token": "token1"}, - {"token": "token2"}, - {"token": "token3"} - ] + expected_tokenized_fields = [{"token": "token1"}, {"token": "token2"}, {"token": "token3"}] self.assertEqual(result.tokenized_fields, expected_tokenized_fields) @@ -365,7 +432,7 @@ def test_parse_query_response_with_valid_records(self): mock_api_response = Mock() mock_api_response.records = [ Mock(fields={"field1": "value1", "field2": "value2"}), - Mock(fields={"field1": "value3", "field2": "value4"}) + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_query_response(mock_api_response) @@ -374,7 +441,7 @@ def test_parse_query_response_with_valid_records(self): expected_fields = [ {"field1": "value1", "field2": "value2", "tokenized_data": {}}, - 
{"field1": "value3", "field2": "value4", "tokenized_data": {}} + {"field1": "value3", "field2": "value4", "tokenized_data": {}}, ] self.assertEqual(result.fields, expected_fields) @@ -382,7 +449,7 @@ def test_parse_query_response_with_valid_records(self): @patch("requests.Response") def test_parse_invoke_connection_response_successful(self, mock_response): mock_response.status_code = 200 - mock_response.content = json.dumps({"key": "value"}).encode('utf-8') + mock_response.content = json.dumps({"key": "value"}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} result = parse_invoke_connection_response(mock_response) @@ -394,19 +461,23 @@ def test_parse_invoke_connection_response_successful(self, mock_response): @patch("requests.Response") def test_parse_invoke_connection_response_json_decode_error(self, mock_response): - + """Test that non-JSON content in successful response is returned as string.""" mock_response.status_code = 200 - mock_response.content = "Non-JSON Content".encode('utf-8') + mock_response.content = "Non-JSON Content".encode("utf-8") + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() - with self.assertRaises(SkyflowError) as context: - parse_invoke_connection_response(mock_response) + result = parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Non-JSON Content")) + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Non-JSON Content") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) @patch("requests.Response") def test_parse_invoke_connection_response_http_error_with_json_error_message(self, mock_response): mock_response.status_code = 404 - mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode('utf-8') + mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode("utf-8") 
mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("404 Error") @@ -417,10 +488,38 @@ def test_parse_invoke_connection_response_http_error_with_json_error_message(sel self.assertEqual(context.exception.message, "Not Found") self.assertEqual(context.exception.request_id, "1234") + @patch("requests.Response") + def test_parse_invoke_connection_response_with_error_from_client_header(self, mock_response): + from requests.models import HTTPError + + mock_response.status_code = 400 + mock_response.content = json.dumps( + { + "error": { + "message": "Client error", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": 3, + "details": None, + } + } + ).encode("utf-8") + mock_response.headers = { + "x-request-id": "rid-1", + "error-from-client": "true", + } + mock_response.raise_for_status.side_effect = HTTPError("400") + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + err = context.exception + self.assertEqual(err.message, "Client error") + self.assertIsNotNone(err.details) + self.assertTrue(any(d.get("error_from_client") is True for d in err.details)) + @patch("requests.Response") def test_parse_invoke_connection_response_http_error_without_json_error_message(self, mock_response): mock_response.status_code = 500 - mock_response.content = "Internal Server Error".encode('utf-8') + mock_response.content = "Internal Server Error".encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("500 Error") @@ -428,37 +527,32 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Internal Server Error")) + self.assertEqual(context.exception.message, 
SkyflowMessages.Error.API_ERROR.value.format(500)) + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_json_error(self, mock_log_and_reject_error): mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'application/json' - } - mock_error.body = json.dumps({ - "error": { - "message": "JSON error occurred.", - "http_code": 400, - "http_status": "Bad Request", - "grpc_code": "8", - "details": "Detailed message" + mock_error.headers = {"x-request-id": "1234", "content-type": "application/json"} + mock_error.body = json.dumps( + { + "error": { + "message": "JSON error occurred.", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": "8", + "details": "Detailed message", + } } - }).encode('utf-8') + ).encode("utf-8") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "JSON error occurred.", - 400, - "1234", - "Bad Request", - "8", - "Detailed message", - logger=mock_logger + "JSON error occurred.", 400, "1234", "Bad Request", "8", "Detailed message", logger=mock_logger ) def test_validate_api_key_valid_key(self): @@ -494,12 +588,7 @@ def test_parse_deidentify_text_response(self): mock_entity.value = "sensitive_value" mock_entity.entity_type = "EMAIL" mock_entity.entity_scores = {"EMAIL": 0.95} - mock_entity.location = Mock( - start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 - ) + mock_entity.location = Mock(start_index=10, end_index=20, start_index_processed=15, end_index_processed=25) mock_api_response = Mock() mock_api_response.processed_text = "Sample processed text" @@ -556,10 +645,7 @@ def test__convert_detected_entity_to_entity_info(self): mock_detected_entity.entity_type = "EMAIL" mock_detected_entity.entity_scores = {"EMAIL": 0.95} mock_detected_entity.location = Mock( - 
start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 + start_index=10, end_index=20, start_index_processed=15, end_index_processed=25 ) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -580,12 +666,7 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): mock_detected_entity.value = None mock_detected_entity.entity_type = "UNKNOWN" mock_detected_entity.entity_scores = {} - mock_detected_entity.location = Mock( - start_index=0, - end_index=0, - start_index_processed=0, - end_index_processed=0 - ) + mock_detected_entity.location = Mock(start_index=0, end_index=0, start_index_processed=0, end_index_processed=0) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -597,3 +678,925 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.text_index.end, 0) self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_connect_error(self, mock_log_and_reject_error): + """Test handling httpx.ConnectError.""" + import httpx + + mock_error = httpx.ConnectError("Connection refused") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Connection refused", SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_headers_attribute(self, mock_log_and_reject_error): + """Test handling error without headers attribute.""" + mock_error = Exception("Generic error") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Generic error", SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def 
test_handle_exception_no_body_attribute(self, mock_log_and_reject_error): + """Test handling error without body attribute.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "12345"} + delattr(mock_error, "body") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_text_plain_error(self, mock_log_and_reject_error): + """Test handling text/plain content type error.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "1234", "content-type": "text/plain"} + mock_error.body = "Plain text error message" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with("Plain text error message", 500, "1234", logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_error): + """Test handling generic error with unknown content type.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "1234", "content-type": "application/xml"} + mock_error.body = "XML error" + mock_error.status = 503 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, "1234", logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_content_type(self, mock_log_and_reject_error): + """Test handling error without content-type header.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "1234"} + mock_error.body = "Some error" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + 
mock_log_and_reject_error.assert_called_once_with(str(mock_error), 500, "1234", logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): + """Test handling JSON error when data is a JSON string.""" + error_json_string = json.dumps( + { + "error": { + "message": "String JSON error", + "http_code": 422, + "http_status": "Unprocessable Entity", + "grpc_code": 3, + "details": ["validation failed"], + } + } + ) + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-3" + + handle_json_error(mock_error, error_json_string, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "String JSON error", 422, request_id, "Unprocessable Entity", 3, ["validation failed"], logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_invalid_json(self, mock_log_and_reject_error): + """Test handling JSON decode error.""" + invalid_json = "This is not valid JSON" + mock_error = Mock() + mock_error.status = 500 + mock_logger = Mock() + request_id = "test-request-id-4" + + handle_json_error(mock_error, invalid_json, request_id, mock_logger) + + # Should call with INVALID_JSON_RESPONSE error + mock_log_and_reject_error.assert_called_once() + self.assertEqual(mock_log_and_reject_error.call_args[0][0], SkyflowMessages.Error.INVALID_JSON_RESPONSE.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_missing_error_field(self, mock_log_and_reject_error): + """Test handling JSON error with missing error field.""" + error_dict = {"message": "Error without error wrapper"} + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-5" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + # Should use defaults for missing fields + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] 
+ # Default message when error field is missing + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + # Default status code + self.assertEqual(args[1], 500) + self.assertEqual(args[2], request_id) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_text_error_with_status(self, mock_log_and_reject_error): + """Test handle_text_error extracts status correctly.""" + mock_error = Mock() + mock_error.status = 404 + mock_logger = Mock() + request_id = "test-request-id-6" + error_data = "Resource not found" + + from skyflow.utils._utils import handle_text_error + + handle_text_error(mock_error, error_data, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with("Resource not found", 404, request_id, logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_generic_error_with_status(self, mock_log_and_reject_error): + """Test handle_generic_error_with_status.""" + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-7" + status = 503 + + from skyflow.utils._utils import handle_generic_error_with_status + + handle_generic_error_with_status(mock_error, request_id, status, mock_logger) + + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, request_id, logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_none_error(self, mock_log_and_reject_error): + """Test handling None error object.""" + mock_logger = Mock() + + handle_exception(None, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + SkyflowMessages.ErrorCodes.SERVER_ERROR.value, + None, + logger=mock_logger, + ) + + # failed + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_error): + """Test handling empty string error.""" + mock_logger = Mock() + mock_error = 
Mock() + mock_error.headers = None + mock_error.body = None + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + # Should use str(error) or default message + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_responses_key(self, mock_log_and_reject_error): + """Test handle_json_error when body has 'responses' key (batch/continue_on_error path).""" + error_dict = { + "responses": [ + {"Status": 400, "Body": {"error": "record not found"}}, + {"Status": 400, "Body": {"error": "invalid field"}}, + ] + } + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-responses" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertIn("record not found", args[0]) + self.assertIn("invalid field", args[0]) + self.assertEqual(args[1], 400) + self.assertIsNone(args[3]) # http_status + self.assertIsNone(args[4]) # grpc_code + self.assertEqual(args[5], []) # details + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_responses_no_error_messages(self, mock_log_and_reject_error): + """Test handle_json_error with responses key but no error body — falls back to default message.""" + error_dict = { + "responses": [ + {"Status": 200, "Body": {"records": [{"skyflow_id": "abc"}]}}, + ] + } + mock_error = Mock() + request_id = "test-request-id-responses-empty" + + handle_json_error(mock_error, error_dict, request_id, None) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_bytes_data(self, 
mock_log_and_reject_error): + """Test handling JSON error when data is bytes.""" + error_dict = {"error": {"message": "Bytes error", "http_code": 401, "http_status": "Unauthorized"}} + error_bytes = json.dumps(error_dict).encode("utf-8") + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-8" + + handle_json_error(mock_error, error_bytes, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Bytes error", 401, request_id, "Unauthorized", None, [], logger=mock_logger + ) + + # Add these new test methods to the TestUtils class: + + def test_construct_invoke_connection_request_with_no_headers(self): + """Test construct_invoke_connection_request when headers are None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param1": "value1"} + mock_connection_request.headers = None + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {"query": "test"} + + connection_url = "https://example.com/{param1}/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Headers should be None when not provided + self.assertIsNone(result.headers.get("Content-Type")) + + def test_construct_invoke_connection_request_with_xml_content_type(self): + """Test construct_invoke_connection_request with XML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/xml"} + mock_connection_request.body = {"root": {"child": "value"}} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, 
PreparedRequest) + self.assertEqual(result.headers["content-type"], "application/xml") + # Body should be converted to XML + self.assertIn("", result.body) + self.assertIn("value", result.body) + + def test_construct_invoke_connection_request_with_html_content_type(self): + """Test construct_invoke_connection_request with HTML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "text/html"} + mock_connection_request.body = {"message": "Hello"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertEqual(result.headers["content-type"], "text/html") + # Body should be JSON string for HTML + self.assertEqual(result.body, json.dumps({"message": "Hello"})) + + def test_construct_invoke_connection_request_multipart_removes_content_type(self): + """Test that Content-Type is removed for multipart/form-data.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} + mock_connection_request.body = {"field1": "value1", "field2": "value2"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Content-Type should be auto-generated by requests library + self.assertIn("multipart/form-data", result.headers.get("Content-Type", "")) + self.assertIn("boundary=", result.headers.get("Content-Type", "")) + + def test_construct_invoke_connection_request_with_no_body(self): + """Test 
construct_invoke_connection_request when body is None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertIsNone(result.body) + + def test_get_data_from_content_type_url_encoded(self): + """Test get_data_from_content_type with URL encoded content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key1": "value1", "key2": "value2"} + content_type = ContentType.URLENCODED.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, "key1=value1&key2=value2") + self.assertEqual(files, {}) + + def test_get_data_from_content_type_form_data(self): + """Test get_data_from_content_type with form data content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"field1": "value1", "field2": "value2"} + content_type = ContentType.FORMDATA.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIsNone(converted_data) + self.assertEqual(files["field1"], (None, "value1")) + self.assertEqual(files["field2"], (None, "value2")) + + def test_get_data_from_content_type_json(self): + """Test get_data_from_content_type with JSON content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = ContentType.JSON.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def 
test_get_data_from_content_type_xml_with_dict(self): + """Test get_data_from_content_type with XML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"root": {"child": "value"}} + content_type = "application/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIn("", converted_data) + self.assertIn("value", converted_data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_xml_with_string(self): + """Test get_data_from_content_type with XML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "value" + content_type = "text/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_dict(self): + """Test get_data_from_content_type with HTML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"message": "Hello"} + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_string(self): + """Test get_data_from_content_type with HTML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "Hello" + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_dict(self): + """Test get_data_from_content_type with unknown content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = "application/custom" + + converted_data, files = 
get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_string(self): + """Test get_data_from_content_type with unknown content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "plain text data" + content_type = "text/plain" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_dict_to_xml_simple_dict(self): + """Test dict_to_xml with simple dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"name": "John", "age": "30"} + result = dict_to_xml(data) + + self.assertIn("John", result) + self.assertIn("30", result) + self.assertTrue(result.startswith("")) + self.assertTrue(result.endswith("")) + + def test_dict_to_xml_nested_dict(self): + """Test dict_to_xml with nested dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"person": {"name": "John", "age": "30"}} + result = dict_to_xml(data) + + self.assertIn("", result) + self.assertIn("John", result) + self.assertIn("30", result) + + def test_dict_to_xml_with_list(self): + """Test dict_to_xml with list values.""" + from skyflow.utils._utils import dict_to_xml + + data = {"items": ["item1", "item2", "item3"]} + result = dict_to_xml(data) + + self.assertIn("item1", result) + self.assertIn("item2", result) + self.assertIn("item3", result) + + @patch("requests.Response") + def test_parse_invoke_connection_response_xml_content(self, mock_response): + """Test parsing XML response content.""" + mock_response.status_code = 200 + mock_response.content = b"success" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/xml"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, 
InvokeConnectionResponse) + self.assertEqual(result.data, "success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_url_encoded_content(self, mock_response): + """Test parsing URL encoded response content.""" + mock_response.status_code = 200 + mock_response.content = b"card_number=4111111111111111&cvv=123" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/x-www-form-urlencoded"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "card_number=4111111111111111&cvv=123") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_content(self, mock_response): + """Test parsing HTML response content.""" + mock_response.status_code = 200 + mock_response.content = b"Success" + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_error(self, mock_response): + """Test parsing HTML error response.""" + html_error = "

Error 500

" + mock_response.status_code = 500 + mock_response.content = html_error.encode("utf-8") + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} + mock_response.raise_for_status = Mock(side_effect=HTTPError("500 Error")) + + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.API_ERROR.value.format(500)) + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") + + @patch("requests.Response") + def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, mock_response): + """Test that JSON decode error falls back to returning string content.""" + mock_response.status_code = 200 + mock_response.content = b"Not valid JSON but still success" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/json"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Not valid JSON but still success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_no_content_type_with_json(self, mock_response): + """Test parsing response with no content-type but valid JSON.""" + mock_response.status_code = 200 + mock_response.content = json.dumps({"success": True}).encode("utf-8") + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, {"success": True}) + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def 
test_parse_invoke_connection_response_no_content_type_with_text(self, mock_response): + """Test parsing response with no content-type and non-JSON content.""" + mock_response.status_code = 200 + mock_response.content = b"Plain text response" + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Plain text response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_bytes_content(self, mock_response): + """Test parsing response with bytes content.""" + mock_response.status_code = 200 + mock_response.content = b"Binary data response" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/octet-stream"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Binary data response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + def test_construct_invoke_connection_request_headers_json_error(self): + """Test exception handling when json.dumps fails for headers.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + + class UnserializableObject: + def __repr__(self): + raise TypeError("Object is not JSON serializable") + + mock_connection_request.headers = {"key": UnserializableObject()} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("json.dumps", side_effect=TypeError("Object is not JSON serializable")): + with self.assertRaises(SkyflowError) as context: + 
construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_headers_generic_exception(self): + """Test generic exception handling for headers processing.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/json"} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("skyflow.utils._utils.to_lowercase_keys", side_effect=Exception("Generic error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_processing_exception(self): + """Test exception handling when body processing fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("skyflow.utils._utils.get_data_from_content_type", side_effect=Exception("Body processing error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, 
SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_json_dumps_exception(self): + """Test exception handling when json.dumps fails in get_data_from_content_type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + + class UnserializableObject: + pass + + mock_connection_request.body = {"key": UnserializableObject()} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_invalid_url_exception(self): + """Test exception handling when requests.Request.prepare() fails with invalid URL.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("requests.Request") as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Invalid URL structure") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, 
SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_prepare_exception(self): + """Test exception handling when prepare() method fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("requests.Request") as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Prepare failed") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_not_dict_raises_error(self): + """Test that non-dict body raises SkyflowError which is caught and re-raised.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = "not a dict" # Invalid body type + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + 
self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + @patch("skyflow.utils._utils.validate_invoke_connection_params") + def test_construct_invoke_connection_request_validation_exception(self, mock_validate): + """Test that validation exceptions are properly propagated.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param": "value"} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {"query": "value"} + + connection_url = "https://example.com/endpoint" + + mock_validate.side_effect = SkyflowError("Validation failed", 400) + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, "Validation failed") + self.assertEqual(context.exception.http_code, 400) + + def test_generate_bearer_token_invalid_token_uri_type(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": 12345, # invalid type + } + + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_options_override_token_uri(self): + 
creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + # Patch AuthClient and jwt.encode to avoid real HTTP and signing + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + generate_bearer_token(tmp.name, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_bearer_token_from_creds_invalid_token_uri_type(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_options_override_token_uri(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": 
"https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + generate_bearer_token_from_creds(creds_str, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_signed_data_tokens_invalid_token_uri_type(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + 
generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_options_override_token_uri(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens(tmp.name, options) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") + + def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens_from_creds(creds_str, options) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], 
"signed_token_signed") diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 48332a55..c5ad6b79 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -12,13 +12,15 @@ validate_insert_request, validate_delete_request, validate_query_request, validate_get_detect_run_request, validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request, validate_invoke_connection_params, - validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request + validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request, + validate_file_upload_request ) from skyflow.utils import SkyflowMessages from skyflow.utils.enums import DetectEntities, RedactionType from skyflow.vault.data import GetRequest, UpdateRequest from skyflow.vault.detect import DeidentifyTextRequest, Transformations, DateTransformation, ReidentifyTextRequest, \ - FileInput, DeidentifyFileRequest + FileInput, DeidentifyFileRequest, Bleep +from skyflow.vault.data._file_upload_request import FileUploadRequest from skyflow.vault.tokens import DetokenizeRequest from skyflow.vault.connection._invoke_connection_request import InvokeConnectionRequest @@ -116,7 +118,7 @@ def test_validate_credentials_with_expired_token(self): with patch('skyflow.service_account.is_expired', return_value=True): with self.assertRaises(SkyflowError) as context: validate_credentials(self.logger, credentials) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value) def test_validate_credentials_empty_credentials(self): credentials = {} @@ -205,15 +207,6 @@ def test_validate_update_vault_config_valid(self): } self.assertTrue(validate_update_vault_config(self.logger, config)) - def 
test_validate_update_vault_config_missing_credentials(self): - config = { - "vault_id": "vault123", - "cluster_id": "cluster123" - } - with self.assertRaises(SkyflowError) as context: - validate_update_vault_config(self.logger, config) - self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123")) - def test_validate_update_vault_config_invalid_cluster_id(self): config = { "vault_id": "vault123", @@ -226,6 +219,18 @@ def test_validate_update_vault_config_invalid_cluster_id(self): validate_update_vault_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format("vault123")) + def test_validate_update_vault_config_missing_credentials(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123") + ) + def test_validate_connection_config_valid(self): config = { "connection_id": "conn123", @@ -259,6 +264,18 @@ def test_validate_connection_config_empty_connection_id(self): validate_connection_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value) + def test_validate_connection_config_missing_credentials(self): + config = { + "connection_id": "conn123", + "connection_url": "https://example.com", + } + with self.assertRaises(SkyflowError) as context: + validate_connection_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", "conn123") + ) + def test_validate_update_connection_config_valid(self): config = { "connection_id": "conn123", @@ -1040,7 +1057,436 @@ def test_validate_detokenize_request_invalid_continue_on_error_type(self): 
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value) def test_validate_detokenize_request_invalid_redaction_type(self): - request = DetokenizeRequest(data=[{"token": "token123", "redaction": "invalid"}], continue_on_error=False) + request = DetokenizeRequest(data=[{"token": "token123", "redaction_type": "invalid"}], continue_on_error=False) with self.assertRaises(SkyflowError) as context: validate_detokenize_request(self.logger, request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) + + def test_validate_detokenize_request_deprecated_redaction_key_emits_warn(self): + from unittest.mock import patch + request = DetokenizeRequest(data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT}], continue_on_error=False) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_called_once() + self.assertIn("redaction_type", mock_warn.call_args[0][0]) + + def test_validate_detokenize_request_both_keys_prioritizes_redaction_type_and_warns(self): + from unittest.mock import patch + request = DetokenizeRequest( + data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT, "redaction_type": RedactionType.MASKED}], + continue_on_error=False + ) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_called_once() + + def test_validate_detokenize_request_redaction_type_only_no_warn(self): + from unittest.mock import patch + request = DetokenizeRequest(data=[{"token": "token123", "redaction_type": RedactionType.PLAIN_TEXT}], continue_on_error=False) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_not_called() + + + def 
test_validate_deidentify_file_request_wait_time_negative(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=-1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_greater_than_64(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=65, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_lower(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=0, + entities=[DetectEntities.SSN] + ) + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_upper(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=64, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_float(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=32.5, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_float_out_of_range(self): + file_input = FileInput(file_path=self.temp_file_path) + request = 
DeidentifyFileRequest( + file=file_input, + wait_time=64.1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + def test_validate_credentials_with_valid_token_uri(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + } + # Should not raise + validate_credentials(self.logger, credentials) + + def test_validate_credentials_with_invalid_token_uri_type(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 # Not a string + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_credentials_with_invalid_token_uri_url(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_valid_token_uri(self): + from skyflow.utils.enums import Env + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + }, + "env": Env.DEV + } + # Should not raise + self.assertTrue(validate_update_vault_config(self.logger, config)) + + def test_validate_update_vault_config_with_invalid_token_uri_type(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 + } + } + with 
self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_invalid_token_uri_url(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + # --- validate_file_from_request --- + + def test_validate_file_from_request_none_input(self): + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_INPUT.value) + + def test_validate_file_from_request_file_without_name_attr(self): + file_obj = MagicMock(spec=[]) # no attributes at all + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_file_with_empty_name(self): + file_obj = MagicMock() + file_obj.name = " " # whitespace-only name + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_extension_only_name(self): + file_obj = MagicMock() + # A trailing-slash path gives os.path.basename() == "", so splitext returns ("", "") + file_obj.name = "/some/directory/" + file_input = MagicMock() + file_input.file = file_obj 
+ file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_from_request_empty_string_file_path(self): + file_input = MagicMock() + file_input.file = None + file_input.file_path = "" # empty string — has_file_path=True, so goes to elif branch + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) + + # --- validate_deidentify_file_request bleep sub-fields --- + + def test_validate_deidentify_file_request_invalid_bleep_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, bleep="not_a_bleep") + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_TYPE.value) + + def test_validate_deidentify_file_request_invalid_bleep_gain(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(gain="loud") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_GAIN.value) + + def test_validate_deidentify_file_request_invalid_bleep_frequency(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(frequency="high") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value) + + def 
test_validate_deidentify_file_request_invalid_bleep_start_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(start_padding="early") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value) + + def test_validate_deidentify_file_request_invalid_bleep_stop_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(stop_padding="late") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value) + + # --- validate_deidentify_file_request output_directory --- + + def test_validate_deidentify_file_request_invalid_output_directory_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=123) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value) + + def test_validate_deidentify_file_request_output_directory_not_found(self): + file_input = FileInput(file_path=self.temp_file_path) + nonexistent = "/tmp/skyflow_nonexistent_dir_12345" + request = DeidentifyFileRequest(file=file_input, output_directory=nonexistent) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(nonexistent) + ) + + def test_validate_deidentify_file_request_valid_output_directory(self): + file_input = 
FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=self.temp_dir_path) + validate_deidentify_file_request(self.logger, request) + + # --- validate_file_upload_request --- + + def test_validate_file_upload_request_none(self): + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_table(self): + request = MagicMock() + request.table = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_empty_table(self): + request = MagicMock() + request.table = " " + request.column_name = "file_col" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(None)) + ) + + def test_validate_file_upload_request_empty_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = "" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type("")) + ) + + def 
test_validate_file_upload_request_empty_skyflow_id(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + skyflow_id=" ", + file_path=self.temp_file_path + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD") + ) + + def test_validate_file_upload_request_invalid_file_object_seek(self): + file_obj = MagicMock() + file_obj.seek.side_effect = OSError("seek failed") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=file_obj + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_OBJECT.value) + + def test_validate_file_upload_request_valid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path=self.temp_file_path + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_invalid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path="/nonexistent/path/file.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_PATH.value) + + def test_validate_file_upload_request_valid_base64(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64=encoded, + file_name="sample.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_base64_without_file_name(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + 
table="test_table", + column_name="file_col", + base64=encoded + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_upload_request_invalid_base64_string(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64="not-valid-base64!!!", + file_name="sample.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BASE64_STRING.value) + + def test_validate_file_upload_request_valid_file_object(self): + with open(self.temp_file_path, "rb") as f: + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=f + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_missing_file_source(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + # --- validate_deidentify_text_request transformations --- + + def test_validate_deidentify_text_request_invalid_transformations(self): + request = DeidentifyTextRequest( + text="test text", + transformations="invalid_type" + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value) + + # --- validate_reidentify_text_request masked_entities --- + + def test_validate_reidentify_text_request_invalid_masked_entities(self): + request = ReidentifyTextRequest( + text="test text", + masked_entities="invalid_type" + ) + with self.assertRaises(SkyflowError) 
as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value) diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 565b1e6f..75826128 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -1,5 +1,8 @@ import unittest from unittest.mock import patch, MagicMock + +from skyflow.error import SkyflowError +from skyflow.utils import SkyflowMessages from skyflow.vault.client.client import VaultClient CONFIG = { @@ -12,11 +15,19 @@ } CREDENTIALS_WITH_API_KEY = {"api_key": "dummy_api_key"} +CREDENTIALS_WITH_TOKEN = {"token": "dummy_static_token"} +CREDENTIALS_WITH_PATH = {"path": "/some/path/credentials.json"} +CREDENTIALS_WITH_STRING = {"credentials_string": '{"clientID": "x"}'} + class TestVaultClient(unittest.TestCase): def setUp(self): self.vault_client = VaultClient(CONFIG) + # ------------------------------------------------------------------ # + # Basic setters / getters # + # ------------------------------------------------------------------ # + def test_set_common_skyflow_credentials(self): credentials = {"api_key": "dummy_api_key"} self.vault_client.set_common_skyflow_credentials(credentials) @@ -28,73 +39,289 @@ def test_set_logger(self): self.assertEqual(self.vault_client.get_log_level(), "INFO") self.assertEqual(self.vault_client.get_logger(), mock_logger) + def test_get_vault_id(self): + self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + + def test_get_config(self): + self.assertEqual(self.vault_client.get_config(), CONFIG) + + def test_get_common_skyflow_credentials(self): + credentials = {"api_key": "dummy_api_key"} + self.vault_client.set_common_skyflow_credentials(credentials) + self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + + def test_get_log_level(self): + self.vault_client.set_logger("DEBUG", 
MagicMock()) + self.assertEqual(self.vault_client.get_log_level(), "DEBUG") + + def test_get_logger(self): + mock_logger = MagicMock() + self.vault_client.set_logger("INFO", mock_logger) + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + # ------------------------------------------------------------------ # + # initialize_client_configuration — first call (slow path) # + # ------------------------------------------------------------------ # + @patch("skyflow.vault.client.client.get_credentials") @patch("skyflow.vault.client.client.get_vault_url") @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") - def test_initialize_client_configuration(self, mock_init_api_client, mock_get_vault_url, mock_get_credentials): - mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY) + def test_initialize_client_configuration_first_call( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY mock_get_vault_url.return_value = "https://test-vault-url.com" self.vault_client.initialize_client_configuration() - mock_get_credentials.assert_called_once_with(CONFIG["credentials"], None, logger=None) - mock_get_vault_url.assert_called_once_with(CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None) + mock_get_credentials.assert_called_once_with( + CONFIG["credentials"], None, logger=None + ) + mock_get_vault_url.assert_called_once_with( + CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None + ) mock_init_api_client.assert_called_once() - @patch("skyflow.vault.client.client.Skyflow") - def test_initialize_api_client(self, mock_api_client): - self.vault_client.initialize_api_client("https://test-vault-url.com", "dummy_token") - mock_api_client.assert_called_once_with(base_url="https://test-vault-url.com", token="dummy_token") + # ------------------------------------------------------------------ # + # initialize_client_configuration — 
fast path (static token) # + # ------------------------------------------------------------------ # - def test_get_records_api(self): + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_api_key( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with api_key, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + # Side-effect simulates initialize_api_client actually setting __api_client + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() # first call — slow path + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() # second call — fast path + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_static_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with a static token, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_TOKEN + mock_get_vault_url.return_value = "https://test-vault-url.com" + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() + mock_get_credentials.reset_mock() 
+ mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (service account) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_valid_sa_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, mock_is_expired + ): + """Service account with a still-valid token skips get_bearer_token entirely.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Seed the cached bearer token as if first call already ran self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.records = MagicMock() - records_api = self.vault_client.get_records_api() - self.assertIsNotNone(records_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "cached_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_tokens_api(self): + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — token expiry (no client reinit) # + # 
------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_sa_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_expired_token_no_reinit( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, + mock_is_expired, mock_generate_bearer_token + ): + """Expired service account token is regenerated in-place; httpx client is NOT recreated.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Client already initialized — simulate warm state with an expired token self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.tokens = MagicMock() - tokens_api = self.vault_client.get_tokens_api() - self.assertIsNotNone(tokens_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "expired_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_query_api(self): + self.vault_client.initialize_client_configuration() + + # Token was regenerated + mock_generate_bearer_token.assert_called_once() + self.assertEqual( + self.vault_client._VaultClient__bearer_token, "new_sa_token" + ) + # httpx client was NOT recreated + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — config update forces reinit # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + 
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_reinit_after_update_config( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """update_config() marks the client stale; next call must recreate it.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Simulate already-initialized client self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.query = MagicMock() - query_api = self.vault_client.get_query_api() - self.assertIsNotNone(query_api) + self.vault_client._VaultClient__is_static_token = True - def test_get_vault_id(self): - self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + self.vault_client.update_config({"cluster_id": "new_cluster"}) + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_called_once() + mock_get_vault_url.assert_called_once() + mock_init_api_client.assert_called_once() + + # ------------------------------------------------------------------ # + # initialize_api_client — lambda token provider # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_passes_callable_token(self, mock_skyflow): + """initialize_api_client must pass a callable (lambda) as token, not a string.""" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + args, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["base_url"], "https://test-vault-url.com") + self.assertTrue(callable(kwargs["token"]), "token must be a callable (lambda)") + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_returns_cached_bearer_token(self, mock_skyflow): + """Lambda returns __bearer_token when it is set (interceptor behaviour).""" + 
self.vault_client._VaultClient__bearer_token = "refreshed_token" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "refreshed_token") + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_falls_back_to_initial_token(self, mock_skyflow): + """Lambda falls back to the initial token when __bearer_token is None.""" + self.vault_client._VaultClient__bearer_token = None + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "initial_token") + + # ------------------------------------------------------------------ # + # get_bearer_token # + # ------------------------------------------------------------------ # + + def test_get_bearer_token_with_api_key(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) + self.assertEqual(result, "dummy_api_key") + + def test_get_bearer_token_with_static_token(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_TOKEN) + self.assertEqual(result, "dummy_static_token") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("sa_token", None)) + def test_get_bearer_token_generates_from_path_on_first_call(self, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token") + self.assertEqual(self.vault_client._VaultClient__bearer_token, "sa_token") + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds", return_value=("sa_token_str", None)) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_generates_from_credentials_string(self, mock_log, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_STRING) + mock_generate.assert_called_once() + 
self.assertEqual(result, "sa_token_str") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_regenerates_on_expiry(self, mock_log, mock_is_expired, mock_generate): + """Expired token is regenerated silently — no exception raised.""" + self.vault_client._VaultClient__bearer_token = "expired_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "new_token") @patch("skyflow.vault.client.client.generate_bearer_token") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.is_expired", return_value=False) @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token, - mock_generate_bearer_token_from_creds): - token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) - self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"]) - - def test_update_config(self): - new_config = {"credentials": "new_credentials"} - self.vault_client.update_config(new_config) + def test_get_bearer_token_reuses_valid_cached_token(self, mock_log, mock_is_expired, mock_generate): + """Valid cached token is reused without calling generate_bearer_token.""" + self.vault_client._VaultClient__bearer_token = "valid_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_not_called() + self.assertEqual(result, "valid_token") + + # ------------------------------------------------------------------ # + # update_config # + # ------------------------------------------------------------------ # + + def test_update_config_sets_flag(self): + self.vault_client.update_config({"credentials": "new_credentials"}) 
self.assertTrue(self.vault_client._VaultClient__is_config_updated) self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") - def test_get_config(self): - self.assertEqual(self.vault_client.get_config(), CONFIG) + # ------------------------------------------------------------------ # + # API accessor stubs # + # ------------------------------------------------------------------ # - def test_get_common_skyflow_credentials(self): - credentials = {"api_key": "dummy_api_key"} - self.vault_client.set_common_skyflow_credentials(credentials) - self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + def test_get_records_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_records_api()) - def test_get_log_level(self): - log_level = "DEBUG" - self.vault_client.set_logger(log_level, MagicMock()) - self.assertEqual(self.vault_client.get_log_level(), log_level) + def test_get_tokens_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_tokens_api()) - def test_get_logger(self): - mock_logger = MagicMock() - self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) \ No newline at end of file + def test_get_query_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_query_api()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 4ccad1c7..f073264c 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -1,9 +1,11 @@ +import json import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock import requests from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages, 
parse_invoke_connection_response -from skyflow.utils.enums import RequestMethod +from skyflow.utils._utils import get_data_from_content_type, construct_invoke_connection_request +from skyflow.utils.enums import RequestMethod, ContentType from skyflow.utils._version import SDK_VERSION from skyflow.vault.connection import InvokeConnectionRequest from skyflow.vault.controller import Connection @@ -30,10 +32,16 @@ def setUp(self): self.mock_vault_client = Mock() self.mock_vault_client.get_config.return_value = VAULT_CONFIG self.mock_vault_client.get_bearer_token.return_value = VALID_BEARER_TOKEN + self.mock_vault_client.get_logger.return_value = Mock() + self.mock_vault_client.get_common_skyflow_credentials.return_value = None self.connection = Connection(self.mock_vault_client) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_success(self, mock_send): + def test_invoke_success(self, mock_send, mock_get_credentials): + # Mock get_credentials to return credentials + mock_get_credentials.return_value = {"api_key": "test_api_key"} + # Mocking successful response mock_response = Mock() mock_response.status_code = SUCCESS_STATUS_CODE @@ -60,9 +68,36 @@ def test_invoke_success(self, mock_send): } self.assertEqual(vars(response), expected_response) self.mock_vault_client.get_bearer_token.assert_called_once() + mock_get_credentials.assert_called_once() + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_invalid_headers(self, mock_send): + def test_invoke_with_x_skyflow_authorization_already_present(self, mock_send, mock_get_credentials): + """Test that X-Skyflow-Authorization is not overwritten if already present in headers.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = 
{'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + custom_auth = "custom_bearer_token" + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers={"x-skyflow-authorization": custom_auth} + ) + + response = self.connection.invoke(request) + + # Verify bearer token from vault_client is NOT used + self.assertIsNotNone(response) + + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_headers(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=VALID_BODY, @@ -75,8 +110,10 @@ def test_invoke_invalid_headers(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) - @patch('requests.Session.send') - def test_invoke_invalid_body(self, mock_send): + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_body(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=INVALID_BODY, @@ -89,11 +126,16 @@ def test_invoke_invalid_body(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_request_error(self, mock_send): + def test_invoke_request_error(self, mock_send, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_response = Mock() mock_response.status_code = FAILURE_STATUS_CODE - mock_response.content = ERROR_RESPONSE_CONTENT + mock_response.content = ERROR_RESPONSE_CONTENT.encode('utf-8') # Convert to bytes + mock_response.headers = {"x-request-id": "test-request-id"} + mock_response.raise_for_status.side_effect = 
requests.HTTPError("400 Error") mock_send.return_value = mock_response request = InvokeConnectionRequest( @@ -106,9 +148,100 @@ def test_invoke_request_error(self, mock_send): with self.assertRaises(SkyflowError) as context: self.connection.invoke(request) - self.assertEqual(context.exception.message, f'Skyflow Python SDK {SDK_VERSION} Response {ERROR_RESPONSE_CONTENT} is not valid JSON.') - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(ERROR_RESPONSE_CONTENT)) - self.assertEqual(context.exception.http_code, 400) + + expected_message = SkyflowMessages.Error.API_ERROR.value.format(FAILURE_STATUS_CODE) + self.assertEqual(context.exception.message, expected_message) + self.assertEqual(context.exception.http_code, FAILURE_STATUS_CODE) + self.assertEqual(context.exception.request_id, "test-request-id") + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_send_exception(self, mock_send, mock_get_credentials): + """Test handling of generic exception from session.send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_send.side_effect = Exception("Network error") + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_skyflow_error_re_raised(self, mock_send, mock_get_credentials): + """Test that SkyflowError is re-raised without wrapping.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + original_error = SkyflowError("Original error", 401) + 
mock_send.side_effect = original_error + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + # Should be the same original error + self.assertEqual(context.exception.message, "Original error") + self.assertEqual(context.exception.http_code, 401) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_close_called(self, mock_send, mock_get_credentials): + """Test that session.close() is called after send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + with patch('requests.Session.close') as mock_close: + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + response = self.connection.invoke(request) + + # Verify close was called + mock_close.assert_called_once() + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.get_metrics') + @patch('requests.Session.send') + def test_invoke_adds_sky_metadata_header(self, mock_send, mock_get_metrics, mock_get_credentials): + """Test that sky-metadata header is added to request.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_get_metrics.return_value = {"sdk_version": SDK_VERSION} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + 
response = self.connection.invoke(request) + + # Verify get_metrics was called + mock_get_metrics.assert_called_once() + self.assertIsNotNone(response) def test_parse_invoke_connection_response_error_from_client(self): mock_response = Mock(spec=requests.Response) @@ -128,3 +261,415 @@ def test_parse_invoke_connection_response_error_from_client(self): self.assertTrue(any(detail.get('error_from_client') == True for detail in exception.details)) self.assertEqual(exception.request_id, '12345') + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.construct_invoke_connection_request') + def test_invoke_construct_request_called(self, mock_construct, mock_get_credentials): + """Test that construct_invoke_connection_request is called with correct parameters.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_prepared_request = Mock(spec=requests.PreparedRequest) + mock_prepared_request.headers = {} + mock_construct.return_value = mock_prepared_request + + with patch('requests.Session.send') as mock_send: + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + self.connection.invoke(request) + + # Verify construct was called with connection_url from config + mock_construct.assert_called_once_with( + request, + VAULT_CONFIG["connection_url"], + self.mock_vault_client.get_logger() + ) + + +class TestGetDataFromContentType(unittest.TestCase): + """Tests for get_data_from_content_type covering all supported content types.""" + + DATA = {'key': 'value', 'num': 42} + + # ── JSON ────────────────────────────────────────────────────────────────── + def test_json_content_type_returns_json_string(self): + data, files = 
get_data_from_content_type(self.DATA, ContentType.JSON.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + # ── URL-encoded ─────────────────────────────────────────────────────────── + def test_urlencoded_content_type_returns_encoded_string(self): + data, files = get_data_from_content_type({'k': 'v'}, ContentType.URLENCODED.value) + self.assertIn('k=v', data) + self.assertEqual(files, {}) + + def test_urlencoded_nested_dict(self): + payload = {'a': {'b': 'c'}} + data, files = get_data_from_content_type(payload, ContentType.URLENCODED.value) + self.assertIsInstance(data, str) + self.assertIn('c', data) + self.assertEqual(files, {}) + + # ── Form-data ───────────────────────────────────────────────────────────── + def test_formdata_content_type_returns_files_dict(self): + data, files = get_data_from_content_type({'f1': 'v1', 'f2': 'v2'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertEqual(files, {'f1': (None, 'v1'), 'f2': (None, 'v2')}) + + def test_formdata_converts_values_to_str(self): + data, files = get_data_from_content_type({'num': 99}, ContentType.FORMDATA.value) + self.assertEqual(files['num'], (None, '99')) + + def test_formdata_single_key(self): + data, files = get_data_from_content_type({'only': 'one'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertIn('only', files) + + # ── XML ─────────────────────────────────────────────────────────────────── + def test_xml_text_xml_content_type_wraps_in_root(self): + data, files = get_data_from_content_type({'key': 'value'}, 'text/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertIn('', data) + self.assertEqual(files, {}) + + def test_xml_application_xml_content_type(self): + data, files = get_data_from_content_type({'key': 'value'}, 'application/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_content_type_enum_value(self): + data, files = 
get_data_from_content_type({'key': 'value'}, ContentType.XML.value) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw_string', 'text/xml') + self.assertEqual(data, 'raw_string') + self.assertEqual(files, {}) + + # ── HTML ────────────────────────────────────────────────────────────────── + def test_html_content_type_dict_returns_json_string(self): + data, files = get_data_from_content_type(self.DATA, ContentType.HTML.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_text_html_content_type(self): + data, files = get_data_from_content_type(self.DATA, 'text/html') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw', ContentType.HTML.value) + self.assertEqual(data, 'raw') + self.assertEqual(files, {}) + + # ── None / unknown ──────────────────────────────────────────────────────── + def test_none_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, None) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, 'application/octet-stream') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_non_dict_returns_str(self): + data, files = get_data_from_content_type('blob', 'application/octet-stream') + self.assertEqual(data, 'blob') + self.assertEqual(files, {}) + + +class TestParseInvokeConnectionResponse(unittest.TestCase): + """Tests for parse_invoke_connection_response covering all success and error paths.""" + + def _make_response(self, status_code, content, headers=None, raise_http_error=False): + mock_resp = Mock(spec=requests.Response) + 
mock_resp.status_code = status_code + if isinstance(content, str): + mock_resp.content = content.encode('utf-8') + else: + mock_resp.content = content + mock_resp.headers = headers or {} + if raise_http_error: + mock_resp.raise_for_status.side_effect = requests.HTTPError() + else: + mock_resp.raise_for_status.return_value = None + return mock_resp + + # ── Success paths ───────────────────────────────────────────────────────── + def test_success_json_content_type_parses_body(self): + resp = self._make_response( + 200, + '{"result": "ok"}', + {'content-type': 'application/json', 'x-request-id': 'req-1'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'result': 'ok'}) + self.assertEqual(result.metadata.get('request_id'), 'req-1') + self.assertIsNone(result.errors) + + def test_success_plain_text_content_type_returns_string(self): + resp = self._make_response( + 200, + 'plain text response', + {'content-type': 'text/plain'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'plain text response') + + def test_success_no_content_type_tries_json_parse(self): + resp = self._make_response(200, '{"a": 1}', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'a': 1}) + + def test_success_no_content_type_invalid_json_returns_string(self): + resp = self._make_response(200, 'not json', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'not json') + + def test_success_no_x_request_id_metadata_is_empty(self): + resp = self._make_response(200, '{}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.metadata, {}) + + def test_success_invalid_json_with_json_content_type_returns_raw_string(self): + resp = self._make_response( + 200, + 'not-json', + {'content-type': 'application/json'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 
'not-json') + + def test_success_bytes_content_decoded(self): + resp = self._make_response(200, b'{"x": 1}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'x': 1}) + + # ── Error paths — standard Skyflow format ──────────────────────────────── + def test_error_standard_skyflow_format_extracts_message(self): + body = json.dumps({'error': {'message': 'bad input', 'http_code': 400, 'http_status': 'BAD_REQUEST', 'grpc_code': 3, 'details': []}}) + resp = self._make_response(400, body, {'x-request-id': 'r1'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + e = ctx.exception + self.assertEqual(e.message, 'bad input') + self.assertEqual(e.http_code, 400) + self.assertEqual(e.request_id, 'r1') + self.assertEqual(e.http_status, 'BAD_REQUEST') + self.assertEqual(e.grpc_code, 3) + + def test_error_standard_format_falls_back_to_http_code_when_missing(self): + body = json.dumps({'error': {'message': 'oops'}}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertEqual(ctx.exception.http_code, 500) + + def test_error_standard_format_falls_back_to_sdk_message_when_missing(self): + body = json.dumps({'error': {}}) + resp = self._make_response(503, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — string error value ───────────────────────────────────── + def test_error_string_error_value_used_as_message(self): + body = json.dumps({'error': 'gateway timed out'}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + 
self.assertEqual(ctx.exception.message, 'gateway timed out') + + def test_error_empty_string_error_value_falls_back_to_sdk_message(self): + body = json.dumps({'error': ''}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-standard JSON ────────────────────────────────────── + def test_error_no_error_key_uses_sdk_message(self): + body = json.dumps({'message': 'something went wrong'}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_non_dict_json_body_uses_sdk_message(self): + body = json.dumps(['list', 'not', 'dict']) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_numeric_error_value_uses_sdk_message(self): + body = json.dumps({'error': 12345}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-JSON / empty body ────────────────────────────────── + def test_error_empty_body_uses_sdk_message(self): + resp = self._make_response(502, '', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) 
+ self.assertEqual(ctx.exception.message, expected) + self.assertEqual(ctx.exception.http_code, 502) + + def test_error_html_body_uses_sdk_message(self): + resp = self._make_response(502, 'Bad Gateway', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + def test_error_plain_text_body_uses_sdk_message(self): + resp = self._make_response(503, 'Service Unavailable', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── error-from-client header ────────────────────────────────────────────── + def test_error_from_client_true_appended_to_details(self): + body = json.dumps({'error': {'message': 'client error', 'http_code': 400, 'details': []}}) + resp = self._make_response(400, body, {'error-from-client': 'true', 'x-request-id': 'r2'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is True for d in ctx.exception.details)) + + def test_error_from_client_false_appended_to_details(self): + body = json.dumps({'error': {'message': 'server error', 'http_code': 500}}) + resp = self._make_response(500, body, {'error-from-client': 'false'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is False for d in ctx.exception.details)) + + def test_error_from_client_initialises_details_when_none(self): + body = json.dumps({'error': {'message': 'err', 'http_code': 400}}) + resp = self._make_response(400, body, {'error-from-client': 'true'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as 
ctx: + parse_invoke_connection_response(resp) + self.assertIsNotNone(ctx.exception.details) + self.assertTrue(len(ctx.exception.details) > 0) + + +class TestConstructInvokeConnectionRequest(unittest.TestCase): + """Tests for construct_invoke_connection_request covering method, body, headers, path/query params.""" + + BASE_URL = 'https://example.com/api' + LOGGER = Mock() + + def _make_request(self, method=RequestMethod.POST, body=None, headers=None, + path_params=None, query_params=None): + return InvokeConnectionRequest( + method=method, + body=body, + headers=headers, + path_params=path_params or {}, + query_params=query_params or {} + ) + + def test_post_with_json_body_prepares_request(self): + req = self._make_request(body={'k': 'v'}, headers={'Content-Type': 'application/json'}) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIn('k', prepared.body) + + def test_get_with_no_body(self): + req = self._make_request(method=RequestMethod.GET) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'GET') + + def test_urlencoded_body_is_form_encoded(self): + req = self._make_request( + body={'field': 'val'}, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('field=val', prepared.body) + + def test_formdata_body_produces_multipart_request(self): + req = self._make_request( + body={'file_field': 'data'}, + headers={'Content-Type': 'multipart/form-data'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIsNotNone(prepared.body) + + def test_xml_body_contains_xml_tags(self): + req = self._make_request( + body={'item': 'data'}, + headers={'Content-Type': 'text/xml'} + ) + prepared = 
construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('', prepared.body) + + def test_path_params_substituted_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + path_params={'id': '123'} + ) + url_with_placeholder = 'https://example.com/api/{id}/resource' + prepared = construct_invoke_connection_request(req, url_with_placeholder, self.LOGGER) + self.assertIn('123', prepared.url) + self.assertNotIn('{id}', prepared.url) + + def test_query_params_appear_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + query_params={'page': '1', 'limit': '10'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('page=1', prepared.url) + self.assertIn('limit=10', prepared.url) + + def test_invalid_headers_raises_skyflow_error(self): + req = InvokeConnectionRequest(method=RequestMethod.POST, headers='bad-headers') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + + def test_invalid_body_raises_skyflow_error(self): + req = InvokeConnectionRequest( + method=RequestMethod.POST, + body='not-a-dict', + headers={'Content-Type': 'application/json'} + ) + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + + def test_invalid_method_raises_skyflow_error(self): + req = InvokeConnectionRequest(method='INVALID_METHOD') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_METHOD.value) + + def test_trailing_slash_stripped_from_url(self): + req = self._make_request(method=RequestMethod.GET) + prepared = 
construct_invoke_connection_request(req, self.BASE_URL + '/', self.LOGGER) + self.assertNotIn('//', prepared.url.replace('https://', '')) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index c2f9a861..b86087f5 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch, MagicMock import base64 import os +import tempfile from skyflow.error import SkyflowError from skyflow.generated.rest import WordCharacterCount from skyflow.utils import SkyflowMessages @@ -513,16 +514,12 @@ def test_get_detect_run_in_progress_status(self, mock_validate): self.vault_client.get_detect_file_api.return_value = files_api - # Execute - with patch.object(self.detect, "_Detect__parse_deidentify_file_response") as mock_parse: - result = self.detect.get_detect_run(req) + # Execute — IN_PROGRESS is returned directly without going through the parser + result = self.detect.get_detect_run(req) - # Verify IN_PROGRESS handling - mock_parse.assert_called_once() - args = mock_parse.call_args[0][0] - self.assertIsInstance(args, DeidentifyFileResponse) - self.assertEqual(args.status, 'IN_PROGRESS') - self.assertEqual(args.run_id, run_id) + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, 'IN_PROGRESS') + self.assertEqual(result.run_id, run_id) def test_get_transformations_with_shift_dates(self): @@ -711,3 +708,98 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) + + def test_poll_for_processed_file_exception(self): + files_api = Mock() + files_api.with_raw_response = files_api + files_api.get_run.side_effect = Exception("poll error") + self.vault_client.get_detect_file_api.return_value = files_api + with 
self.assertRaises(Exception): + self.detect._Detect__poll_for_processed_file("runid", max_wait_time=5) + + def test_save_output_directory_not_exists(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=False): + self.detect._Detect__save_deidentify_file_response_output( + response, "/nonexistent_dir", "file.txt", "file" + ) + + def test_save_output_second_non_redacted_item(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output1 = Mock() + output1.processedFile = base64.b64encode(b"data1").decode() + output1.processedFileType = "redacted_file" + output1.processedFileExtension = "txt" + output2 = Mock() + output2.processedFile = base64.b64encode(b"data2").decode() + output2.processedFileType = "entities" + output2.processedFileExtension = "json" + response = Mock() + response.output = [output1, output2] + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "original.txt", "original" + ) + + def test_save_output_path_traversal_blocked(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + call_count = [0] + + def fake_realpath(p): + call_count[0] += 1 + if call_count[0] == 1: + return "/safe_dir" + return "/outside/path" + + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=True), \ + patch("skyflow.vault.controller._detect.os.path.realpath", side_effect=fake_realpath): + self.detect._Detect__save_deidentify_file_response_output( + response, "/safe_dir", "file.txt", "file" + ) + + def test_save_output_write_exception(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output = Mock() + 
output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.base64.b64decode", + side_effect=Exception("decode error")), \ + self.assertRaises(Exception): + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "file.txt", "file" + ) + + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") + @patch("skyflow.vault.controller._detect.base64") + def test_deidentify_file_api_error_inside_try(self, mock_base64, mock_validate): + file_content = b"test content" + file_obj = Mock() + file_obj.read.return_value = file_content + file_obj.name = "test.txt" + mock_base64.b64encode.return_value.decode.return_value = "encoded" + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) + req.entities = [] + req.token_format = None + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = None + req.wait_time = None + files_api = Mock() + files_api.with_raw_response = files_api + files_api.deidentify_text.side_effect = Exception("API error inside try") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect.deidentify_file(req) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 4e1a0dda..993cd72a 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -722,6 +722,26 @@ def test_upload_file_with_missing_file_source(self, mock_validate): self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_without_skyflow_id_successful(self, mock_validate): + 
"""Test upload_file succeeds when skyflow_id is None (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/test.txt", + ) + mocked_open = mock_open_func(read_data=b"test file content") + mock_api_response = Mock() + mock_api_response.data = Mock(skyflow_id="generated-id-123") + records_api = self.vault_client.get_records_api.return_value + records_api.with_raw_response.upload_file_v_2.return_value = mock_api_response + with patch('builtins.open', mocked_open): + result = self.vault.upload_file(request) + mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + self.assertIsNone(request.skyflow_id) + self.assertEqual(result.skyflow_id, "generated-id-123") + self.assertIsNone(result.errors) + class TestFileUploadValidation(unittest.TestCase): def setUp(self): self.logger = Mock() @@ -874,3 +894,38 @@ def test_validate_missing_file_source(self): with self.assertRaises(SkyflowError) as error: validate_file_upload_request(self.logger, request) self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + def test_validate_none_skyflow_id_is_allowed(self): + """Test that skyflow_id=None passes validation (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + base64="dGVzdCBmaWxlIGNvbnRlbnQ=", + file_name="test.txt" + ) + self.assertIsNone(request.skyflow_id) + validate_file_upload_request(self.logger, request) + + @patch('os.path.exists') + @patch('os.path.isfile') + def test_validate_file_path_without_skyflow_id(self, mock_isfile, mock_exists): + """Test validation succeeds with file_path and no skyflow_id.""" + mock_exists.return_value = True + mock_isfile.return_value = True + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/file.txt" + ) + validate_file_upload_request(self.logger, request) + + def 
test_validate_file_object_without_skyflow_id(self): + """Test validation succeeds with file_object and no skyflow_id.""" + mock_file = Mock() + mock_file.seek = Mock() + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_object=mock_file + ) + validate_file_upload_request(self.logger, request)