General ML Smell Detector

Source Code

   1import os
   2import sys
   3from typing import Any, Dict, List
   4
   5import astroid
   6from astroid import nodes
   7
   8
   9class ML_SmellDetector:
  10    """A detector for identifying common code smells in machine learning code.
  11
  12    This class analyzes Python source code to detect potential issues and anti-patterns
  13    commonly found in machine learning projects, such as data leakage, improper feature
  14    scaling, missing cross-validation, and more.
  15    """
  16
  17    def __init__(self):
  18        """Initialize the ML smell detector with empty collections for tracking analysis results."""
  19        self.smells: List[Dict[str, Any]] = []
  20        self.imports: Dict[str, Any] = {}
  21        self.variables: Dict[str, Any] = {}
  22        self.functions: Dict[str, Any] = {}
  23        self.classes: Dict[str, Any] = {}
  24
  25    def add_smell(self, smell: str, node: nodes.NodeNG, file_path: str):
  26        """Add a detected code smell to the collection.
  27
  28        Args:
  29            smell: Description of the detected smell
  30            node: The AST node where the smell was detected
  31            file_path: Path to the file containing the smell
  32        """
  33        self.smells.append({
  34            "smell": smell,
  35            "line_number": node.lineno,
  36            "code_snippet": node.as_string(),
  37            "file_path": file_path
  38        })
  39
  40    def detect_smells(self, file_path: str) -> List[Dict[str, Any]]:
  41        """Analyze a Python file for ML-related code smells.
  42
  43        Args:
  44            file_path: Path to the Python file to analyze
  45
  46        Returns:
  47            List of dictionaries containing detected smell information
  48        """
  49        try:
  50            with open(file_path, 'r') as file:
  51                content = file.read()
  52            module_name = os.path.splitext(os.path.basename(file_path))[0]
  53            module = astroid.parse(content, module_name=module_name)
  54
  55            # Check if any ML-related packages are imported
  56            ml_packages = ['pandas', 'numpy', 'sklearn', 'tensorflow', 'torch', 'transformers']
  57            if any(self.is_package_used(module, package) for package in ml_packages):
  58                self.visit_module(module, file_path)
  59            else:
  60                print(f"Skipping ML smell detection for {file_path}: No ML-related packages imported", file=sys.stderr)
  61        except astroid.exceptions.AstroidSyntaxError as e:
  62            print(f"Error parsing {file_path}: {str(e)}", file=sys.stderr)
  63        except Exception as e:
  64            print(f"Unexpected error while processing {file_path}: {str(e)}", file=sys.stderr)
  65        return self.smells
  66
  67    def is_package_used(self, node: nodes.Module, package: str) -> bool:
  68        """Check if a specific package is imported in the module.
  69
  70        Args:
  71            node: AST node representing the module
  72            package: Name of the package to check for
  73
  74        Returns:
  75            True if the package is imported, False otherwise
  76        """
  77        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
  78            if isinstance(import_node, nodes.Import):
  79                if any(name.split('.')[0] == package for name, _ in import_node.names):
  80                    return True
  81            elif isinstance(import_node, nodes.ImportFrom):
  82                if import_node.modname.split('.')[0] == package:
  83                    return True
  84        return False
  85
  86    def visit_module(self, node: nodes.Module, file_path: str):
  87        """Run all smell detection checks on a module.
  88
  89        Args:
  90            node: AST node representing the module
  91            file_path: Path to the file being analyzed
  92        """
  93        self.check_imports(node, file_path)
  94        self.check_data_leakage(node, file_path)
  95        self.check_magic_numbers(node, file_path)
  96        self.check_feature_scaling(node, file_path)
  97        self.check_cross_validation(node, file_path)
  98        self.check_imbalanced_dataset(node, file_path)
  99        self.check_feature_selection(node, file_path)
 100        self.check_metric_selection(node, file_path)
 101        self.check_model_persistence(node, file_path)
 102        self.check_reproducibility(node, file_path)
 103        self.check_data_loading(node, file_path)
 104        self.check_unused_features(node, file_path)
 105        self.check_overfit_prone_practices(node, file_path)
 106        self.check_error_handling(node, file_path)
 107        self.check_hardcoded_filepaths(node, file_path)
 108        self.check_documentation(node, file_path)
 109
 110    def check_imports(self, node: nodes.Module, file_path: str):
 111        """Track imports used in the module for later analysis.
 112
 113        Args:
 114            node: AST node representing the module
 115            file_path: Path to the file being analyzed
 116        """
 117        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 118            if isinstance(import_node, nodes.Import):
 119                for name, alias in import_node.names:
 120                    self.imports[alias or name] = name
 121            elif isinstance(import_node, nodes.ImportFrom):
 122                for name, alias in import_node.names:
 123                    self.imports[alias or name] = f"{import_node.modname}.{name}"
 124
 125    @staticmethod
 126    def _is_inside_function(call_node):
 127        """Return True if call_node is nested inside a FunctionDef."""
 128        parent = call_node.parent
 129        while parent:
 130            if isinstance(parent, nodes.FunctionDef):
 131                return True
 132            if isinstance(parent, nodes.Module):
 133                break
 134            parent = parent.parent
 135        return False
 136
 137    def _check_leakage_in_scope(self, scope_node, report_node, file_path):
 138        """Check a single scope (module or function) for fit-before-split leakage."""
 139        preprocessing_before_split = False
 140        train_test_split_pos = -1
 141        last_preprocessing_pos = -1
 142
 143        for call in scope_node.nodes_of_class(nodes.Call):
 144            if call.func.as_string().endswith(('fit', 'fit_transform')):
 145                preprocessing_before_split = True
 146                last_preprocessing_pos = call.lineno
 147            elif 'train_test_split' in call.func.as_string():
 148                train_test_split_pos = call.lineno
 149
 150        if preprocessing_before_split and train_test_split_pos > 0:
 151            if last_preprocessing_pos < train_test_split_pos:
 152                self.add_smell(
 153                    "Potential data leakage: Preprocessing applied before train-test split",
 154                    report_node,
 155                    file_path)
 156
 157    def check_data_leakage(self, node: nodes.Module, file_path: str):
 158        """Detect potential data leakage from preprocessing before train-test split.
 159
 160        Args:
 161            node: AST node representing the module
 162            file_path: Path to the file being analyzed
 163        """
 164        # Check inside each function/method (existing behaviour)
 165        for func in node.nodes_of_class(nodes.FunctionDef):
 166            self._check_leakage_in_scope(func, func, file_path)
 167
 168        # Also check module-level code that sits outside any function
 169        preprocessing_before_split = False
 170        train_test_split_pos = -1
 171        last_preprocessing_pos = -1
 172        for call in node.nodes_of_class(nodes.Call):
 173            if self._is_inside_function(call):
 174                continue
 175            if call.func.as_string().endswith(('fit', 'fit_transform')):
 176                preprocessing_before_split = True
 177                last_preprocessing_pos = call.lineno
 178            elif 'train_test_split' in call.func.as_string():
 179                train_test_split_pos = call.lineno
 180        if preprocessing_before_split and train_test_split_pos > 0:
 181            if last_preprocessing_pos < train_test_split_pos:
 182                self.add_smell(
 183                    "Potential data leakage: Preprocessing applied before train-test split "
 184                    "(module-level code)",
 185                    node,
 186                    file_path)
 187
 188    def check_magic_numbers(self, node: nodes.Module, file_path: str):
 189        """Detect magic numbers in ML-related code.
 190
 191        Args:
 192            node: AST node representing the module
 193            file_path: Path to the file being analyzed
 194        """
 195        # Common acceptable values that shouldn't trigger warnings
 196        acceptable_values = {0, 1, -1, 100, 0.5, 2}  # Common ML-related constants
 197        for assign in node.nodes_of_class(nodes.Assign):
 198            if isinstance(assign.value, nodes.Const) and isinstance(assign.value.value, (int, float)):
 199                # Skip if it's an acceptable value
 200                if assign.value.value in acceptable_values:
 201                    continue
 202                # Skip if the variable name suggests it's a legitimate constant
 203                if any(assign.targets[0].as_string().lower().startswith(prefix) for prefix in
 204                       ['num_', 'size_', 'batch_', 'epoch', 'learning_rate', 'lr_', 'threshold_']):
 205                    continue
 206                self.add_smell(f"Magic number detected: {assign.value.value}", assign, file_path)
 207
 208    def check_feature_scaling(self, node: nodes.Module, file_path: str):
 209        """Detect inconsistent feature scaling methods across the codebase.
 210
 211        Looks for multiple different scaling methods (StandardScaler, MinMaxScaler, etc.)
 212        being used, which could lead to inconsistent results.
 213
 214        Args:
 215            node: AST node representing the module
 216            file_path: Path to the file being analyzed
 217        """
 218        scaling_methods = ['StandardScaler', 'MinMaxScaler', 'RobustScaler']
 219        scalers_used = set()
 220
 221        for call in node.nodes_of_class(nodes.Call):
 222            if any(method in call.func.as_string() for method in scaling_methods):
 223                scalers_used.add(call.func.as_string())
 224
 225        # Only raise warning if multiple different scaling methods are used
 226        if len(scalers_used) > 1:
 227            scalers = ', '.join(scalers_used)
 228            self.add_smell(
 229                f"Inconsistent scaling methods detected: {scalers}. Consider using the same scaler across the pipeline.",
 230                call,
 231                file_path)
 232
 233    def check_cross_validation(self, node: nodes.Module, file_path: str):
 234        """Check if cross-validation is properly implemented in model training.
 235
 236        Detects if common cross-validation methods (KFold, cross_val_score, etc.)
 237        are missing in model training code.
 238
 239        Args:
 240            node: AST node representing the module
 241            file_path: Path to the file being analyzed
 242        """
 243        cv_methods = [
 244            'cross_val_score',
 245            'KFold',
 246            'cross_validate',
 247            'GridSearchCV',
 248            'RandomizedSearchCV',
 249            'TimeSeriesSplit']
 250        cv_detected = False
 251        is_training_file = False
 252
 253        # Skip if this is likely not a main training file
 254        if any(pattern in file_path.lower() for pattern in [
 255            'test_', 'utils', 'helper', 'preprocessing', 'visualization',
 256            'evaluate', 'predict', 'inference', 'deploy'
 257        ]):
 258            return
 259
 260        # Check imports to see if this is likely a training file
 261        training_imports = {'sklearn.model_selection', 'sklearn.linear_model', 'sklearn.ensemble',
 262                            'tensorflow', 'torch', 'xgboost', 'lightgbm'}
 263        has_training_imports = False
 264        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 265            if isinstance(import_node, nodes.ImportFrom):
 266                if any(imp in import_node.modname for imp in training_imports):
 267                    has_training_imports = True
 268                    break
 269
 270        if not has_training_imports:
 271            return
 272
 273        for call in node.nodes_of_class(nodes.Call):
 274            # Check if file contains model training code
 275            if any(method in call.func.as_string() for method in ['fit', 'train', 'compile']):
 276                # Skip if it's in a test method
 277                if any(ancestor.name.startswith('test_') for ancestor in call.node_ancestors()
 278                       if isinstance(ancestor, nodes.FunctionDef)):
 279                    continue
 280                is_training_file = True
 281
 282            # Check for CV methods
 283            if any(method in call.func.as_string() for method in cv_methods):
 284                cv_detected = True
 285                break
 286
 287            # Also check for custom CV implementations
 288            if ('split' in call.func.as_string() and
 289                    any(val in call.as_string() for val in ['fold', 'cv', 'validation'])):
 290                cv_detected = True
 291                break
 292
 293        # Only raise warning if it's a substantial training file (has multiple model-related calls)
 294        model_related_calls = sum(1 for call in node.nodes_of_class(nodes.Call)
 295                                  if any(term in call.func.as_string()
 296                                         for term in ['fit', 'train', 'predict', 'score']))
 297
 298        if (is_training_file and not cv_detected and model_related_calls >= 2
 299                and not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
 300            self.add_smell(
 301                "Cross-validation not detected in model training code. "
 302                "Consider using cross-validation for more robust evaluation.",
 303                node, file_path
 304            )
 305
 306    def check_imbalanced_dataset(self, node: nodes.Module, file_path: str):
 307        """Check if imbalanced dataset handling techniques are used in classification tasks.
 308
 309        Looks for common techniques like SMOTE, class weights, or stratification
 310        when dealing with classification problems.
 311
 312        Args:
 313            node: AST node representing the module
 314            file_path: Path to the file being analyzed
 315        """
 316        # Skip if this is likely not a main training file
 317        if any(pattern in file_path.lower() for pattern in [
 318            'test_', 'utils', 'helper', 'visualization', 'evaluate',
 319            'predict', 'inference', 'deploy', 'preprocess'
 320        ]):
 321            return
 322
 323        balance_methods = [
 324            'SMOTE', 'class_weight', 'StratifiedKFold', 'RandomOverSampler',
 325            'RandomUnderSampler', 'sample_weight', 'balanced_accuracy',
 326            'WeightedRandomSampler', 'BalancedBaggingClassifier'
 327        ]
 328        imbalance_handling = False
 329        is_classification = False
 330        has_data_processing = False
 331
 332        # Check imports first
 333        classification_imports = {
 334            'sklearn.linear_model', 'sklearn.ensemble', 'sklearn.svm',
 335            'sklearn.tree', 'sklearn.naive_bayes', 'xgboost', 'lightgbm'
 336        }
 337        has_classification_imports = False
 338        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 339            if isinstance(import_node, nodes.ImportFrom):
 340                if any(imp in import_node.modname for imp in classification_imports):
 341                    has_classification_imports = True
 342                    break
 343
 344        if not has_classification_imports:
 345            return
 346
 347        # Check if this is a classification task
 348        for call in node.nodes_of_class(nodes.Call):
 349            if any(clf in call.func.as_string() for clf in [
 350                'Classifier', 'LogisticRegression', 'SVC', 'RandomForestClassifier',
 351                'GradientBoostingClassifier', 'XGBClassifier', 'LGBMClassifier',
 352                'DecisionTreeClassifier', 'KNeighborsClassifier'
 353            ]):
 354                is_classification = True
 355
 356            # Check for data processing/analysis that might indicate class distribution checks
 357            if any(term in call.func.as_string() for term in [
 358                'value_counts', 'unique', 'hist', 'countplot', 'distribution',
 359                'balance_ratio', 'class_distribution'
 360            ]):
 361                has_data_processing = True
 362
 363            # Check for imbalance handling methods
 364            if any(method in call.func.as_string() for method in balance_methods):
 365                imbalance_handling = True
 366                break
 367
 368            # Check for custom handling in strings (e.g., parameter names)
 369            if isinstance(call.func, nodes.Attribute):
 370                if any(term in str(call.args) + str(call.keywords) for term in [
 371                    'weight', 'balanced', 'stratif'
 372                ]):
 373                    imbalance_handling = True
 374                    break
 375
 376        # Count model-related calls to ensure it's a substantial training file
 377        model_related_calls = sum(1 for call in node.nodes_of_class(nodes.Call)
 378                                  if any(term in call.func.as_string()
 379                                         for term in ['fit', 'train', 'predict', 'score']))
 380
 381        # Only raise warning if:
 382        # 1. It's a classification task
 383        # 2. No imbalance handling detected
 384        # 3. Has multiple model-related calls
 385        # 4. Has data processing (suggesting actual data analysis)
 386        # 5. Not a quick example/demo
 387        if (is_classification and not imbalance_handling and
 388            model_related_calls >= 2 and has_data_processing and
 389                not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
 390
 391            self.add_smell(
 392                "No imbalanced dataset handling detected in classification task. "
 393                "Consider techniques like SMOTE or class weights if dealing with imbalanced data.",
 394                node, file_path
 395            )
 396
 397    def check_feature_selection(self, node: nodes.Module, file_path: str):
 398        """Detect feature selection practices and validate their implementation.
 399
 400        Ensures feature selection is performed with proper validation strategy
 401        to avoid selection bias.
 402
 403        Args:
 404            node: AST node representing the module
 405            file_path: Path to the file being analyzed
 406        """
 407        # Skip if this is likely not a feature selection file
 408        if any(pattern in file_path.lower() for pattern in [
 409            'test_', 'utils', 'helper', 'visualization',
 410            'predict', 'inference', 'deploy', 'evaluate'
 411        ]):
 412            return
 413
 414        feature_selection_methods = [
 415            'SelectKBest', 'RFE', 'SelectFromModel', 'PCA', 'VarianceThreshold',
 416            'mutual_info', 'chi2', 'f_classif', 'SelectPercentile',
 417            'GenericUnivariateSelect', 'RFECV', 'SequentialFeatureSelector'
 418        ]
 419        validation_methods = [
 420            'cross_val_score', 'train_test_split', 'GridSearchCV',
 421            'RandomizedSearchCV', 'KFold', 'StratifiedKFold',
 422            'TimeSeriesSplit', 'cross_validate'
 423        ]
 424        feature_selection = False
 425        validation_detected = False
 426        has_ml_imports = False
 427
 428        # Check for relevant imports first
 429        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 430            if isinstance(import_node, nodes.ImportFrom):
 431                if any(pkg in import_node.modname for pkg in [
 432                    'sklearn.feature_selection', 'sklearn.decomposition',
 433                    'sklearn.model_selection'
 434                ]):
 435                    has_ml_imports = True
 436                    break
 437
 438        if not has_ml_imports:
 439            return
 440
 441        for call in node.nodes_of_class(nodes.Call):
 442            # Check for feature selection methods
 443            if any(method in call.func.as_string() for method in feature_selection_methods):
 444                feature_selection = True
 445
 446            # Check for validation methods
 447            if any(method in call.func.as_string() for method in validation_methods):
 448                validation_detected = True
 449
 450            # Check for custom validation in parameter names or strings
 451            if isinstance(call.func, nodes.Attribute):
 452                if any(term in str(call.args) + str(call.keywords) for term in
 453                       ['valid', 'test', 'split', 'cv', 'fold']):
 454                    validation_detected = True
 455
 456        # Count substantial ML operations
 457        ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
 458                            if any(term in call.func.as_string()
 459                                   for term in ['fit', 'transform', 'predict', 'score']))
 460
 461        if (feature_selection and not validation_detected and
 462            ml_operations >= 2 and
 463                not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
 464            self.add_smell(
 465                "Feature selection detected without clear validation strategy. "
 466                "Ensure it's applied with proper validation to avoid selection bias.",
 467                node, file_path
 468            )
 469
 470    def check_metric_selection(self, node: nodes.Module, file_path: str):
 471        """Validate the choice of evaluation metrics for ML models.
 472
 473        Ensures appropriate metrics are used for classification and regression tasks,
 474        warning against using only basic metrics like accuracy.
 475
 476        Args:
 477            node: AST node representing the module
 478            file_path: Path to the file being analyzed
 479        """
 480        # Skip if this is likely not an evaluation file
 481        if any(pattern in file_path.lower() for pattern in [
 482            'test_', 'utils', 'helper', 'preprocess',
 483            'data', 'feature', 'transform'
 484        ]):
 485            return
 486
 487        metrics = set()
 488        is_classification = False
 489        is_regression = False
 490        has_ml_imports = False
 491
 492        # Check imports first
 493        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 494            if isinstance(import_node, nodes.ImportFrom):
 495                if any(pkg in import_node.modname for pkg in [
 496                    'sklearn.metrics', 'sklearn.model_selection',
 497                    'sklearn.linear_model', 'sklearn.ensemble'
 498                ]):
 499                    has_ml_imports = True
 500                    break
 501
 502        if not has_ml_imports:
 503            return
 504
 505        # Determine if it's classification or regression
 506        for call in node.nodes_of_class(nodes.Call):
 507            if any(clf in call.func.as_string() for clf in [
 508                'Classifier', 'LogisticRegression', 'SVC', 'RandomForestClassifier',
 509                'GradientBoostingClassifier', 'XGBClassifier', 'LGBMClassifier'
 510            ]):
 511                is_classification = True
 512            elif any(reg in call.func.as_string() for reg in [
 513                'Regressor', 'LinearRegression', 'SVR', 'RandomForestRegressor',
 514                'GradientBoostingRegressor', 'XGBRegressor', 'LGBMRegressor'
 515            ]):
 516                is_regression = True
 517
 518            # Collect metrics
 519            if any(metric in call.func.as_string() for metric in [
 520                'accuracy_score', 'precision_score', 'recall_score', 'f1_score',
 521                'mean_squared_error', 'r2_score', 'mean_absolute_error',
 522                'roc_auc_score', 'average_precision_score', 'confusion_matrix',
 523                'classification_report', 'explained_variance_score'
 524            ]):
 525                metrics.add(call.func.as_string())
 526
 527            # Check for custom metric implementations
 528            if isinstance(call.func, nodes.Attribute):
 529                if any(term in str(call.args) + str(call.keywords) for term in [
 530                    'metric', 'score', 'evaluation', 'performance'
 531                ]):
 532                    metrics.add('custom_metric')
 533
 534        # Count substantial ML operations
 535        ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
 536                            if any(term in call.func.as_string()
 537                                   for term in ['fit', 'predict', 'score', 'evaluate']))
 538
 539        # Only raise warnings for specific cases with substantial ML usage
 540        if ml_operations >= 2 and not any(term in file_path.lower() for term in ['quick', 'example', 'demo']):
 541            if is_classification and 'accuracy_score' in metrics and len(metrics) == 1:
 542                self.add_smell(
 543                    "Only accuracy metric detected for classification. "
 544                    "Consider adding precision, recall, or F1-score for a more comprehensive evaluation.",
 545                    node, file_path
 546                )
 547            elif is_regression and 'mean_squared_error' in metrics and len(metrics) == 1:
 548                self.add_smell(
 549                    "Only MSE detected for regression. "
 550                    "Consider adding R2 score or MAE for a more comprehensive evaluation.",
 551                    node, file_path
 552                )
 553
 554    def check_model_persistence(self, node: nodes.Module, file_path: str):
 555        """Check model saving practices and associated preprocessing steps.
 556
 557        Ensures models are saved with their preprocessing steps and proper versioning
 558        for reproducibility.
 559
 560        Args:
 561            node: AST node representing the module
 562            file_path: Path to the file being analyzed
 563        """
 564        # Skip if this is likely not a model saving file
 565        if any(pattern in file_path.lower() for pattern in [
 566            'test_', 'utils', 'helper', 'data', 'preprocess',
 567            'explore', 'analyze', 'visualize'
 568        ]):
 569            return
 570
 571        model_save = False
 572        preprocessing_save = False
 573        version_control = False
 574        has_ml_imports = False
 575
 576        # Check for relevant imports first
 577        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 578            if isinstance(import_node, nodes.ImportFrom):
 579                if any(pkg in import_node.modname for pkg in [
 580                    'sklearn', 'tensorflow', 'torch', 'joblib', 'pickle'
 581                ]):
 582                    has_ml_imports = True
 583                    break
 584
 585        if not has_ml_imports:
 586            return
 587
 588        # Check for model training/fitting first
 589        has_model_training = False
 590        for call in node.nodes_of_class(nodes.Call):
 591            if any(term in call.func.as_string() for term in ['fit', 'train']):
 592                has_model_training = True
 593                break
 594
 595        if not has_model_training:
 596            return
 597
 598        for call in node.nodes_of_class(nodes.Call):
 599            # Check for model saving operations
 600            if any(save in call.func.as_string() for save in [
 601                'save', 'dump', 'to_pickle', 'save_model', 'save_weights',
 602                'torch.save', 'joblib.dump'
 603            ]):
 604                model_save = True
 605
 606                # Check for preprocessing steps being saved
 607                if any(prep in call.as_string().lower() for prep in [
 608                    'scaler', 'encoder', 'preprocessor', 'pipeline',
 609                    'transform', 'processor', 'tokenizer'
 610                ]):
 611                    preprocessing_save = True
 612
 613                # Check for version control
 614                if any(ver in call.as_string().lower() for ver in [
 615                    'version', 'v1', 'v2', 'v3', 'timestamp', 'date',
 616                    '_v', '.v', 'release'
 617                ]):
 618                    version_control = True
 619
 620                # Check for version control in variable names
 621                if isinstance(call.func, nodes.Attribute):
 622                    if any(ver in str(call.args) + str(call.keywords) for ver in [
 623                        'version', 'timestamp', 'date', 'release'
 624                    ]):
 625                        version_control = True
 626
 627        # Only raise warnings if we have substantial model operations
 628        ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
 629                            if any(term in call.func.as_string()
 630                                   for term in ['fit', 'train', 'predict', 'transform']))
 631
 632        if ml_operations >= 2 and not any(term in file_path.lower() for term in ['quick', 'example', 'demo']):
 633            if model_save and not preprocessing_save:
 634                self.add_smell(
 635                    "Model saving detected without preprocessing steps. "
 636                    "Remember to save preprocessing steps for proper model deployment.",
 637                    node, file_path
 638                )
 639            elif model_save and not version_control:
 640                self.add_smell(
 641                    "Model saving detected without clear versioning. "
 642                    "Consider adding version control for model artifacts.",
 643                    node, file_path
 644                )
 645
 646    def check_reproducibility(self, node: nodes.Module, file_path: str):
 647        """Check if random seeds are properly set for reproducibility.
 648
 649        Ensures random seeds are set for all relevant libraries (numpy, random,
 650        framework-specific) in ML operations.
 651
 652        Args:
 653            node: AST node representing the module
 654            file_path: Path to the file being analyzed
 655        """
 656        # Skip if this is likely not a training file
 657        if any(pattern in file_path.lower() for pattern in [
 658            'test_', 'utils', 'helper', 'data', 'preprocess',
 659            'explore', 'analyze', 'visualize', 'inference'
 660        ]):
 661            return
 662
 663        seed_methods = {
 664            'random_state', 'seed', 'torch.manual_seed', 'tf.random.set_seed',
 665            'np.random.seed', 'random.seed', 'cuda.manual_seed',
 666            'cuda.manual_seed_all', 'tensorflow.random.set_seed'
 667        }
 668        seeds_set = set()
 669        has_ml_operations = False
 670        has_ml_imports = False
 671
 672        # Check imports first
 673        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 674            if isinstance(import_node, nodes.ImportFrom):
 675                if any(pkg in import_node.modname for pkg in [
 676                    'sklearn', 'tensorflow', 'torch', 'numpy', 'random'
 677                ]):
 678                    has_ml_imports = True
 679                    break
 680
 681        if not has_ml_imports:
 682            return
 683
 684        for call in node.nodes_of_class(nodes.Call):
 685            # Check if file contains substantial ML operations
 686            if any(op in call.func.as_string() for op in [
 687                'fit', 'train', 'predict', 'transform', 'split',
 688                'sample', 'shuffle', 'random'
 689            ]):
 690                has_ml_operations = True
 691
 692            # Check for seed setting
 693            for seed_method in seed_methods:
 694                if seed_method in call.as_string():
 695                    seeds_set.add(seed_method)
 696
 697            # Check for seed parameters in ML operations
 698            if isinstance(call.func, nodes.Attribute):
 699                if any(seed in str(call.args) + str(call.keywords) for seed in [
 700                    'random_state', 'seed', 'deterministic'
 701                ]):
 702                    seeds_set.add('parameter_seed')
 703
 704        # Count substantial ML operations
 705        ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
 706                            if any(term in call.func.as_string()
 707                                   for term in ['fit', 'train', 'predict', 'transform']))
 708
 709        if (ml_operations >= 2 and has_ml_operations and
 710                not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
 711            if not seeds_set:
 712                self.add_smell(
 713                    "No random seed setting detected in ML operations. "
 714                    "Consider setting seeds for reproducibility.",
 715                    node, file_path
 716                )
 717            elif len(seeds_set) < 2 and any(framework in self.imports for framework in ['tensorflow', 'torch']):
 718                self.add_smell(
 719                    "Incomplete seed setting detected. Remember to set seeds for all "
 720                    "relevant libraries (numpy, random, framework-specific).",
 721                    node, file_path
 722                )
 723
 724    def check_data_loading(self, node: nodes.Module, file_path: str):
 725        """Analyze data loading practices for potential issues.
 726
 727        Checks for proper handling of large datasets, including batch processing
 728        and file size checks.
 729
 730        Args:
 731            node: AST node representing the module
 732            file_path: Path to the file being analyzed
 733        """
 734        # Skip if this is likely not a data loading file
 735        if any(pattern in file_path.lower() for pattern in [
 736            'test_', 'utils', 'helper', 'model', 'train',
 737            'evaluate', 'predict', 'inference'
 738        ]):
 739            return
 740
 741        data_loading_methods = {
 742            'read_csv', 'load_data', 'read_excel', 'read_parquet',
 743            'read_json', 'read_sql', 'load_dataset'
 744        }
 745        batch_processing_methods = {
 746            'batch_size', 'generator', 'DataLoader', 'dataset',
 747            'chunk', 'iterator', 'yield', 'flow_from_directory'
 748        }
 749        memory_handling_methods = {
 750            'dask', 'vaex', 'datatable', 'memory_limit',
 751            'low_memory', 'nrows', 'usecols'
 752        }
 753
 754        file_size_check = False
 755        batch_processing = False
 756        memory_handling = False
 757        has_data_imports = False
 758
 759        # Check imports first
 760        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 761            if isinstance(import_node, nodes.ImportFrom):
 762                if any(pkg in import_node.modname for pkg in [
 763                    'pandas', 'dask', 'vaex', 'torch.utils.data',
 764                    'tensorflow.data', 'datasets'
 765                ]):
 766                    has_data_imports = True
 767                    break
 768
 769        if not has_data_imports:
 770            return
 771
 772        for call in node.nodes_of_class(nodes.Call):
 773            if any(method in call.func.as_string() for method in data_loading_methods):
 774                # Check for file size checks
 775                if any(check in call.as_string() for check in [
 776                    'os.path.getsize', 'file_size', 'stat', 'memory_usage'
 777                ]):
 778                    file_size_check = True
 779
 780                # Check for batch processing
 781                if any(method in call.as_string() for method in batch_processing_methods):
 782                    batch_processing = True
 783
 784                # Check for memory handling
 785                if any(method in call.as_string() for method in memory_handling_methods):
 786                    memory_handling = True
 787
 788                # Check for parameters indicating memory consideration
 789                if isinstance(call.func, nodes.Attribute):
 790                    if any(param in str(call.args) + str(call.keywords) for param in [
 791                        'chunksize', 'batch_size', 'iterator', 'memory',
 792                        'nrows', 'usecols', 'dtype'
 793                    ]):
 794                        memory_handling = True
 795
 796        # Only raise warning if we have substantial data loading operations
 797        data_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
 798                              if any(term in call.func.as_string()
 799                                     for term in ['read_', 'load_', 'open']))
 800
 801        if (data_operations >= 2 and
 802            not any(term in file_path.lower() for term in ['quick', 'example', 'demo']) and
 803                not (file_size_check or batch_processing or memory_handling)):
 804            self.add_smell(
 805                "Data loading detected without size checks or batch processing. "
 806                "Consider using generators or batch processing for large datasets.",
 807                node, file_path
 808            )
 809
 810    def check_unused_features(self, node: nodes.Module, file_path: str):
 811        """Detect potentially unused features or variables in ML code.
 812
 813        Identifies variables that are defined but not used, excluding common
 814        variable names and special prefixes.
 815
 816        Args:
 817            node: AST node representing the module
 818            file_path: Path to the file being analyzed
 819        """
 820        # Skip if this is likely not a main code file
 821        if any(pattern in file_path.lower() for pattern in [
 822            'test_', 'conftest', 'setup', '__init__',
 823            'utils', 'helper', 'config', 'constants'
 824        ]):
 825            return
 826
 827        features = set()
 828        used_features = set()
 829        # Expanded list of common variables to ignore
 830        common_vars = {
 831            'self', 'i', 'j', 'k', 'x', 'y', 'X', 'y', 'df', 'data',
 832            'model', 'clf', 'reg', 'pred', 'proba', 'score',
 833            'train', 'test', 'val', 'valid', 'result', 'output',
 834            'input', 'params', 'args', 'kwargs', 'config', 'options'
 835        }
 836
 837        # Skip prefixes for variables that are commonly used in specific ways
 838        skip_prefixes = {
 839            '_', 'temp_', 'tmp_', 'test_', 'debug_', 'log_',
 840            'cache_', 'old_', 'new_', 'raw_', 'processed_'
 841        }
 842
 843        # Skip suffixes that indicate special usage
 844        skip_suffixes = {
 845            '_id', '_idx', '_index', '_key', '_val', '_list',
 846            '_dict', '_map', '_set', '_array', '_df', '_series'
 847        }
 848
 849        has_ml_imports = False
 850        # Check for ML-related imports
 851        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 852            if isinstance(import_node, nodes.ImportFrom):
 853                if any(pkg in import_node.modname for pkg in [
 854                    'sklearn', 'tensorflow', 'torch', 'pandas', 'numpy'
 855                ]):
 856                    has_ml_imports = True
 857                    break
 858
 859        if not has_ml_imports:
 860            return
 861
 862        for assign in node.nodes_of_class(nodes.Assign):
 863            if isinstance(assign.targets[0], nodes.Name):
 864                name = assign.targets[0].name
 865                # Skip if name matches any exclusion criteria
 866                if (name not in common_vars and
 867                    not any(name.startswith(prefix) for prefix in skip_prefixes) and
 868                        not any(name.endswith(suffix) for suffix in skip_suffixes)):
 869                    features.add(name)
 870
 871        # Collect used features from various contexts
 872        for name in node.nodes_of_class(nodes.Name):
 873            used_features.add(name.name)
 874
 875        # Check for usage in attributes
 876        for attr in node.nodes_of_class(nodes.Attribute):
 877            if isinstance(attr.expr, nodes.Name):
 878                used_features.add(attr.expr.name)
 879
 880        unused = features - used_features - common_vars
 881        if unused:
 882            # Additional checks for usage in various contexts
 883            for node_type in [nodes.Const, nodes.Dict, nodes.List, nodes.Set]:
 884                for item in node.nodes_of_class(node_type):
 885                    if isinstance(item.value, str):
 886                        unused = unused - {feat for feat in unused if feat in item.value}
 887
 888            # Check for usage in f-strings
 889            for string in node.nodes_of_class(nodes.JoinedStr):
 890                unused = unused - {feat for feat in unused if feat in string.as_string()}
 891
 892            # Only report if we have substantial ML operations
 893            ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
 894                                if any(term in call.func.as_string()
 895                                       for term in ['fit', 'predict', 'transform']))
 896
 897            if unused and ml_operations >= 2:
 898                self.add_smell(
 899                    f"Potentially unused features detected: {', '.join(unused)}. "
 900                    "Verify if these are actually needed.",
 901                    node, file_path
 902                )
 903
 904    def check_overfit_prone_practices(self, node: nodes.Module, file_path: str):
 905        """Detect practices that might lead to overfitting.
 906
 907        Checks feature engineering functions for proper train/test separation
 908        to avoid data leakage.
 909
 910        Args:
 911            node: AST node representing the module
 912            file_path: Path to the file being analyzed
 913        """
 914        # Skip if this is likely not a feature engineering file
 915        if any(pattern in file_path.lower() for pattern in [
 916            'test_', 'utils', 'helper', 'config', 'constants',
 917            'visualization', 'plotting', 'display', 'logging'
 918        ]):
 919            return
 920
 921        has_ml_imports = False
 922        # Check for ML-related imports
 923        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 924            if isinstance(import_node, nodes.ImportFrom):
 925                if any(pkg in import_node.modname for pkg in [
 926                    'sklearn', 'pandas', 'numpy', 'tensorflow', 'torch'
 927                ]):
 928                    has_ml_imports = True
 929                    break
 930
 931        if not has_ml_imports:
 932            return
 933
 934        for func in node.nodes_of_class(nodes.FunctionDef):
 935            # Check if it's a feature engineering function
 936            if any(term in func.name.lower() for term in [
 937                'feature', 'transform', 'process', 'engineer', 'prepare'
 938            ]):
 939                # Skip if it's clearly safe
 940                if any(safe_term in func.name.lower() for safe_term in [
 941                    'train', 'fit', 'training_only', 'train_set',
 942                    'single', 'one', 'individual', 'row'
 943                ]):
 944                    continue
 945
 946                # Check function body for proper data handling
 947                has_safe_handling = False
 948                for call in func.nodes_of_class(nodes.Call):
 949                    if any(term in call.as_string().lower() for term in [
 950                        'train_test_split', 'validation', 'train_data', 'test_data',
 951                        'fit_transform', 'transform', 'partial_fit', 'single_transform'
 952                    ]):
 953                        has_safe_handling = True
 954                        break
 955
 956                # Check for parameter names indicating proper handling
 957                for arg in func.args.args:
 958                    if any(term in arg.name.lower() for term in [
 959                        'train', 'test', 'valid', 'single', 'row'
 960                    ]):
 961                        has_safe_handling = True
 962                        break
 963
 964                # Only warn if we have substantial data operations
 965                data_operations = sum(1 for call in func.nodes_of_class(nodes.Call)
 966                                      if any(term in call.func.as_string()
 967                                             for term in ['fit', 'transform', 'process']))
 968
 969                if not has_safe_handling and data_operations >= 2:
 970                    self.add_smell(
 971                        "Feature engineering function detected without clear train/test separation. "
 972                        "Ensure it's not applied to the entire dataset to avoid data leakage.",
 973                        func, file_path
 974                    )
 975
 976    def check_error_handling(self, node: nodes.Module, file_path: str):
 977        """Check for proper error handling in critical ML operations.
 978
 979        Ensures try-except blocks or validation checks are present for data
 980        operations and model-related tasks.
 981
 982        Args:
 983            node: AST node representing the module
 984            file_path: Path to the file being analyzed
 985        """
 986        # Skip if this is likely not a main code file
 987        if any(pattern in file_path.lower() for pattern in [
 988            'test_', 'conftest', 'setup', '__init__',
 989            'utils', 'helper', 'config', 'constants',
 990            'visualization', 'plotting', 'display'
 991        ]):
 992            return
 993
 994        has_data_operations = False
 995        has_error_handling = False
 996        critical_operations = [
 997            'read_csv', 'load_data', 'open', 'fit', 'predict',
 998            'transform', 'save', 'dump', 'to_pickle', 'load_model',
 999            'read_excel', 'read_json', 'read_sql'
1000        ]
1001
1002        has_ml_imports = False
1003        # Check for relevant imports
1004        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
1005            if isinstance(import_node, nodes.ImportFrom):
1006                if any(pkg in import_node.modname for pkg in [
1007                    'pandas', 'sklearn', 'tensorflow', 'torch',
1008                    'pickle', 'joblib'
1009                ]):
1010                    has_ml_imports = True
1011                    break
1012
1013        if not has_ml_imports:
1014            return
1015
1016        # Check for critical operations
1017        critical_op_count = 0
1018        for call in node.nodes_of_class(nodes.Call):
1019            if any(op in call.func.as_string() for op in critical_operations):
1020                has_data_operations = True
1021                critical_op_count += 1
1022
1023        if has_data_operations:
1024            # Check for different types of error handling
1025            for block in node.nodes_of_class((nodes.Try, nodes.ExceptHandler)):
1026                has_error_handling = True
1027                break
1028
1029            # Check for validation checks
1030            for if_block in node.nodes_of_class(nodes.If):
1031                if any(check in if_block.as_string().lower() for check in [
1032                    'isinstance', 'isfile', 'exists', 'shape', 'empty', 'null',
1033                    'none', 'isnull', 'isna', 'type', 'hasattr', 'in',
1034                    'validate', 'check', 'verify'
1035                ]):
1036                    has_error_handling = True
1037                    break
1038
1039            # Check for assertion statements
1040            for assert_node in node.nodes_of_class(nodes.Assert):
1041                has_error_handling = True
1042                break
1043
1044            if not has_error_handling and critical_op_count >= 2:
1045                self.add_smell(
1046                    "No error handling detected in data processing. "
1047                    "Consider adding try-except blocks or validation checks for robustness.",
1048                    node, file_path
1049                )
1050
1051    def check_hardcoded_filepaths(self, node: nodes.Module, file_path: str):
1052        """Detect hardcoded file paths in the code.
1053
1054        Identifies hardcoded paths that should be moved to configuration files
1055        or environment variables, excluding common development paths.
1056
1057        Args:
1058            node: AST node representing the module
1059            file_path: Path to the file being analyzed
1060        """
1061        # Common acceptable patterns
1062        acceptable_patterns = [
1063            './test/', './tests/',
1064            '../test/', '../tests/',
1065            'fixtures/', 'data/test/',
1066            '__pycache__', '.git/',
1067            'venv/', 'env/'
1068        ]
1069
1070        config_vars = set()
1071        # First collect any config/environment variables
1072        for assign in node.nodes_of_class(nodes.Assign):
1073            if isinstance(assign.targets[0], nodes.Name):
1074                if any(config_term in assign.targets[0].name.lower() for config_term in
1075                       ['path', 'dir', 'folder', 'file', 'config']):
1076                    config_vars.add(assign.targets[0].name)
1077
1078        for string in node.nodes_of_class(nodes.Const):
1079            if isinstance(string.value, str) and ('/' in string.value or '\\' in string.value):
1080                # Skip if it's a test/common development path
1081                if any(pattern in string.value for pattern in acceptable_patterns):
1082                    continue
1083
1084                # Skip if it's used in a config/path variable assignment
1085                if any(config_var in string.scope().locals for config_var in config_vars):
1086                    continue
1087
1088                self.add_smell(
1089                    f"Hardcoded file path detected: {string.value}. "
1090                    "Consider using configuration files or environment variables.",
1091                    string, file_path
1092                )
1093
1094    def check_documentation(self, node: nodes.Module, file_path: str):
1095        """Check for proper documentation in ML-related code.
1096
1097        Ensures functions and classes have appropriate docstrings, especially
1098        for those with parameters or return values.
1099
1100        Args:
1101            node: AST node representing the module
1102            file_path: Path to the file being analyzed
1103        """
1104        # Skip certain types of files/functions
1105        skip_patterns = [
1106            'test_', 'fixture', 'conftest',
1107            'setup', 'init', 'main',
1108            'helper', 'util'
1109        ]
1110
1111        for func in node.nodes_of_class(nodes.FunctionDef):
1112            # Skip if it's a simple function (few lines)
1113            if len(list(func.get_children())) <= 3:
1114                continue
1115
1116            # Skip if it's a test or utility function
1117            if any(pattern in func.name.lower() for pattern in skip_patterns):
1118                continue
1119
1120            if not isinstance(func.doc_node, nodes.Const):
1121                # Only warn for functions with parameters or return values
1122                if func.args.args or 'return' in func.as_string():
1123                    self.add_smell(
1124                        f"Missing docstring for function: {func.name}. "
1125                        "Consider adding documentation for parameters and return values.",
1126                        func, file_path
1127                    )
1128
1129        for cls in node.nodes_of_class(nodes.ClassDef):
1130            # Skip test classes
1131            if any(pattern in cls.name.lower() for pattern in skip_patterns):
1132                continue
1133
1134            if not isinstance(cls.doc_node, nodes.Const):
1135                # Check if class has public methods
1136                has_public_methods = any(
1137                    not method.name.startswith('_')
1138                    for method in cls.mymethods()
1139                )
1140                if has_public_methods:
1141                    self.add_smell(
1142                        f"Missing docstring for class: {cls.name}. "
1143                        "Consider adding class-level documentation.",
1144                        cls, file_path
1145                    )
1146
1147    def generate_report(self) -> str:
1148        """Generate a human-readable report of all detected smells.
1149
1150        Returns:
1151            Formatted string containing the analysis report
1152        """
1153        report = "General ML Code Smell Report\n============================\n\n"
1154        smell_counts = {}
1155        for i, smell in enumerate(self.smells, 1):
1156            if smell['smell'] not in smell_counts:
1157                smell_counts[smell['smell']] = 0
1158            smell_counts[smell['smell']] += 1
1159
1160            report += f"{i}. Smell: {smell['smell']}\n"
1161            report += f"   File: {smell['file_path']}\n"
1162
1163            # Only show line number if it's not 0
1164            if smell['line_number'] != 0:
1165                report += f"   Line: {smell['line_number']}\n"
1166
1167            # Only include code snippet if it's 3 lines or fewer
1168            code_lines = smell['code_snippet'].strip().split('\n')
1169            if len(code_lines) <= 3:
1170                report += f"   Code Snippet:\n{smell['code_snippet']}\n"
1171            report += "\n"
1172
1173        report += "Smell Counts:\n"
1174        for smell, count in smell_counts.items():
1175            report += f"  {smell}: {count}\n"
1176        report += f"\nTotal smells detected: {len(self.smells)}"
1177        return report
1178
1179    def get_results(self) -> List[Dict[str, str]]:
1180        """Get the analysis results in a structured format.
1181
1182        Returns:
1183            List of dictionaries containing smell details and locations
1184        """
1185        return [
1186            {
1187                'framework': 'General ML',
1188                'name': smell['smell'],
1189                'fix': "Not specified",
1190                'benefits': "Not specified",
1191                'location': f"Line {smell['line_number']}" if smell['line_number'] != 0 else ""
1192            }
1193            for smell in self.smells
1194        ]