Framework-Specific Smell Detector

Source Code

   1import os
   2import sys
   3from typing import Any, Dict, List
   4
   5import astroid
   6from astroid import nodes
   7
   8
   9class FrameworkSpecificSmellDetector:
  10    """A detector for identifying code smells specific to ML frameworks like
  11    Pandas, NumPy, Scikit-learn, PyTorch and TensorFlow.
  12
  13    This class analyzes Python code to detect common anti-patterns and suboptimal implementations
  14    when using popular machine learning frameworks. It provides suggestions for improvements
  15    and best practices.
  16    """
  17
  18    def __init__(self):
  19        self.smells: List[Dict[str, Any]] = []
  20        self.framework_smells = self.get_smells()
  21
  22    def detect_smells(self, file_path: str) -> List[Dict[str, str]]:
  23        """Analyze a Python file for framework-specific code smells.
  24
  25        Args:
  26            file_path: Path to the Python file to analyze
  27
  28        Returns:
  29            List of dictionaries containing detected code smells with details like:
  30            - framework: The ML framework the smell relates to
  31            - name: Name of the code smell
  32            - how_to_fix: Instructions for fixing the issue
  33            - benefits: Benefits of fixing the issue
  34            - line_number: Line where the smell was detected
  35            - code_snippet: The problematic code
  36            - file_path: Path to the file containing the smell
  37        """
  38        try:
  39            with open(file_path, 'r') as file:
  40                content = file.read()
  41            module_name = os.path.splitext(os.path.basename(file_path))[0]
  42            module = astroid.parse(content, module_name=module_name)
  43
  44            if not module:
  45                print(f"Error: Could not parse module for {file_path}", file=sys.stderr)
  46                return self.smells
  47
  48            frameworks_used = self.get_frameworks_used(module)
  49            if frameworks_used is None:
  50                print(f"Error: Could not determine frameworks for {file_path}", file=sys.stderr)
  51                return self.smells
  52
  53            if frameworks_used:
  54                self.visit_module(module, file_path, frameworks_used)
  55            else:
  56                print(
  57                    f"Skipping framework-specific smell detection for {file_path}: No relevant frameworks imported",
  58                    file=sys.stderr)
  59        except astroid.exceptions.AstroidSyntaxError as e:
  60            print(f"Error parsing {file_path}: {str(e)}", file=sys.stderr)
  61        except Exception as e:
  62            print(f"Unexpected error while processing {file_path}: {str(e)}", file=sys.stderr)
  63        return self.smells
  64
  65    def get_frameworks_used(self, node: nodes.Module) -> List[str]:
  66        """Detect which ML frameworks are imported in the code.
  67
  68        Args:
  69            node: AST node representing the Python module
  70
  71        Returns:
  72            List of framework names found in imports (e.g. ['Pandas', 'NumPy'])
  73        """
  74        if not node:
  75            return []
  76
  77        frameworks = []
  78        framework_imports = {
  79            'pandas': 'Pandas',
  80            'numpy': 'NumPy',
  81            'sklearn': 'ScikitLearn',
  82            'tensorflow': 'TensorFlow',
  83            'torch': 'PyTorch'
  84        }
  85
  86        try:
  87            for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
  88                if isinstance(import_node, nodes.Import):
  89                    for name, _ in import_node.names:
  90                        root = name.split('.')[0]
  91                        if root in framework_imports and framework_imports[root] not in frameworks:
  92                            frameworks.append(framework_imports[root])
  93                elif isinstance(import_node, nodes.ImportFrom):
  94                    root = import_node.modname.split('.')[0]
  95                    if root in framework_imports and framework_imports[root] not in frameworks:
  96                        frameworks.append(framework_imports[root])
  97        except Exception as e:
  98            print(f"Error in get_frameworks_used: {str(e)}", file=sys.stderr)
  99            return []
 100
 101        return frameworks
 102
 103    def visit_module(self, node: nodes.Module, file_path: str, frameworks_used: List[str]):
 104        """Visit a Python module to detect framework-specific smells.
 105
 106        Args:
 107            node: AST node representing the Python module
 108            file_path: Path to the file being analyzed
 109            frameworks_used: List of frameworks detected in the code
 110        """
 111        if not node or not frameworks_used:
 112            return
 113
 114        try:
 115            self.check_imports(node, file_path)
 116            for framework in frameworks_used:
 117                if framework:
 118                    detector_method = getattr(self, f"detect_{framework.lower()}_smells", None)
 119                    if detector_method:
 120                        detector_method(node, file_path)
 121        except Exception as e:
 122            print(f"Error in visit_module for {file_path}: {str(e)}", file=sys.stderr)
 123
 124    def add_smell(self, framework: str, smell_name: str, node: nodes.NodeNG, file_path: str):
 125        """Add a detected code smell to the results.
 126
 127        Args:
 128            framework: Name of the framework (e.g. 'Pandas', 'NumPy')
 129            smell_name: Name of the detected smell
 130            node: AST node where the smell was found
 131            file_path: Path to the file containing the smell
 132        """
 133        smell = next((s for s in self.framework_smells[framework] if s['name'] == smell_name), None)
 134        if smell:
 135            self.smells.append({
 136                "framework": framework,
 137                "name": smell['name'],
 138                "how_to_fix": smell['how_to_fix'],
 139                "benefits": smell['benefits'],
 140                "strategies": smell['strategies'],
 141                "line_number": node.lineno,
 142                "code_snippet": node.as_string(),
 143                "file_path": file_path
 144            })
 145
 146    # Pandas Detection Methods
 147    def detect_pandas_smells(self, node: nodes.Module, file_path: str):
 148        """Detect Pandas-specific code smells like:
 149        - Unnecessary iteration instead of vectorization
 150        - Chain indexing
 151        - Missing merge parameters
 152        - Inplace operations
 153        - Inefficient values usage
 154        - Missing dtype specifications
 155        - Suboptimal column selection
 156        - DataFrame modifications in loops
 157
 158        Args:
 159            node: AST node representing the Python module
 160            file_path: Path to the file being analyzed
 161        """
 162        self.detect_iterrows_usage(node, file_path)
 163        self.detect_chain_indexing(node, file_path)
 164        self.detect_merge_parameters(node, file_path)
 165        self.detect_inplace_operations(node, file_path)
 166        self.detect_values_usage(node, file_path)
 167        self.detect_dtype_specification(node, file_path)
 168        self.detect_column_selection(node, file_path)
 169        self.detect_dataframe_modification(node, file_path)
 170
 171    def detect_iterrows_usage(self, node: nodes.Module, file_path: str):
 172        # Only flag iterrows if better alternatives could be used
 173        vectorizable_operations = ['sum', 'mean', 'max', 'min', 'apply', 'map']
 174        for call in node.nodes_of_class(nodes.Call):
 175            if 'iterrows' in call.func.as_string():
 176                # Check if the iterrows is used within a loop
 177                parent = call.parent
 178                while parent and not isinstance(parent, nodes.For):
 179                    parent = parent.parent
 180                if parent and any(op in parent.as_string() for op in vectorizable_operations):
 181                    self.add_smell('Pandas', 'Unnecessary Iteration', call, file_path)
 182
 183    def detect_chain_indexing(self, node: nodes.Module, file_path: str):
 184        for subscript in node.nodes_of_class(nodes.Subscript):
 185            if isinstance(subscript.value, nodes.Subscript):
 186                # Only flag if it's not within a with pd.option_context block
 187                parent = subscript.parent
 188                while parent and not isinstance(parent, nodes.With):
 189                    parent = parent.parent
 190                if not parent or 'option_context' not in parent.as_string():
 191                    # Check if it's an assignment (more dangerous) vs just access
 192                    if isinstance(subscript.parent, nodes.Assign):
 193                        self.add_smell('Pandas', 'Chain Indexing', subscript, file_path)
 194
 195    def detect_merge_parameters(self, node: nodes.Module, file_path: str):
 196        for call in node.nodes_of_class(nodes.Call):
 197            if ('merge' in call.func.as_string() and
 198                isinstance(call.func, nodes.Attribute) and
 199                # Check if it's actually a pandas merge
 200                ('pd' in call.func.expr.as_string() or 'pandas' in call.func.expr.as_string()) and
 201                    not any(kw.arg in ['how', 'on', 'validate'] for kw in call.keywords)):
 202                self.add_smell('Pandas', 'Merge Parameter Checker', call, file_path)
 203
 204    def detect_inplace_operations(self, node: nodes.Module, file_path: str):
 205        inplace_operations = ['sort_values', 'fillna', 'drop', 'replace', 'rename']
 206        for call in node.nodes_of_class(nodes.Call):
 207            if ('inplace=True' in call.as_string() and
 208                    any(op in call.func.as_string() for op in inplace_operations)):
 209                # Check if the result is used later
 210                parent = call.parent
 211                while parent and not isinstance(parent, nodes.Module):
 212                    if isinstance(parent, nodes.Assign):
 213                        break
 214                    parent = parent.parent
 215                if not isinstance(parent, nodes.Assign):
 216                    self.add_smell('Pandas', 'InPlace Checker', call, file_path)
 217
 218    def detect_values_usage(self, node: nodes.Module, file_path: str):
 219        for call in node.nodes_of_class(nodes.Call):
 220            if '.values' in call.as_string():
 221                # Check if it's used in a context where to_numpy() would be better
 222                numpy_contexts = ['np.', 'array', 'asarray', 'reshape', 'transpose']
 223                if any(ctx in call.as_string() for ctx in numpy_contexts):
 224                    self.add_smell('Pandas', 'DataFrame Conversion Checker', call, file_path)
 225
 226    def detect_dtype_specification(self, node: nodes.Module, file_path: str):
 227        for call in node.nodes_of_class(nodes.Call):
 228            if ('read_csv' in call.func.as_string() and
 229                # Check if it's actually pandas read_csv
 230                ('pd' in call.func.as_string() or 'pandas' in call.func.as_string()) and
 231                    not any(kw.arg in ['dtype', 'parse_dates'] for kw in call.keywords)):
 232                # Check if the DataFrame is used in operations sensitive to dtypes
 233                parent = call.parent
 234                dtype_sensitive_ops = ['groupby', 'merge', 'join', 'sort_values', 'arithmetic_operations']
 235                while parent and not isinstance(parent, nodes.Module):
 236                    if any(op in parent.as_string() for op in dtype_sensitive_ops):
 237                        self.add_smell('Pandas', 'Datatype Checker', call, file_path)
 238                        break
 239                    parent = parent.parent
 240
 241    def detect_column_selection(self, node: nodes.Module, file_path: str):
 242        # Only check if there are actual DataFrame operations
 243        df_operations = False
 244        column_selections = False
 245
 246        for subscript in node.nodes_of_class(nodes.Subscript):
 247            if isinstance(subscript.value, nodes.Name):
 248                # Look for DataFrame operations
 249                if any(op in node.as_string() for op in ['DataFrame', 'pd.', 'pandas']):
 250                    df_operations = True
 251                    if '[[' in subscript.as_string():
 252                        column_selections = True
 253                        break
 254
 255        if df_operations and not column_selections:
 256            self.add_smell('Pandas', 'Column Selection Checker', node, file_path)
 257
 258    def detect_dataframe_modification(self, node: nodes.Module, file_path: str):
 259        for assign in node.nodes_of_class(nodes.Assign):
 260            if (isinstance(assign.targets[0], nodes.Subscript) and
 261                    isinstance(assign.targets[0].value, nodes.Name)):
 262                # Check if it's within a loop and modifying a DataFrame
 263                parent = assign.parent
 264                while parent and not isinstance(parent, nodes.For):
 265                    parent = parent.parent
 266
 267                # Check if it's actually a DataFrame modification
 268                df_indicators = ['DataFrame', 'pd.', 'pandas']
 269                if (parent and
 270                    any(indicator in assign.as_string() for indicator in df_indicators) and
 271                        not any(safe_op in assign.as_string() for safe_op in ['loc', 'iloc', 'at', 'iat'])):
 272                    self.add_smell('Pandas', 'DataFrame Iteration Modification', assign, file_path)
 273
 274    # NumPy Detection Methods
 275    def detect_numpy_smells(self, node: nodes.Module, file_path: str):
 276        """Detect NumPy-specific code smells like:
 277        - NaN equality comparisons
 278        - Missing random seeds
 279        - Inefficient array creation
 280        - Non-vectorized operations
 281        - Dtype inconsistencies
 282        - Broadcasting issues
 283        - Copy/view confusion
 284        - Missing axis specifications
 285
 286        Args:
 287            node: AST node representing the Python module
 288            file_path: Path to the file being analyzed
 289        """
 290        self.detect_nan_equality(node, file_path)
 291        self.detect_random_seed(node, file_path)
 292        self.detect_array_creation(node, file_path)
 293        self.detect_inefficient_operations(node, file_path)
 294        self.detect_dtype_consistency(node, file_path)
 295        self.detect_broadcasting_issues(node, file_path)
 296        self.detect_copy_view_issues(node, file_path)
 297        self.detect_axis_specification(node, file_path)
 298
 299    def detect_nan_equality(self, node: nodes.Module, file_path: str):
 300        for compare in node.nodes_of_class(nodes.Compare):
 301            if any('np.nan' in getattr(op[1], 'as_string', lambda: '')() for op in compare.ops):
 302                self.add_smell('NumPy', 'NaN Equality Checker', compare, file_path)
 303
 304    def detect_random_seed(self, node: nodes.Module, file_path: str):
 305        # List of numpy random operations to check for
 306        random_operations = [
 307            'np.random.rand',
 308            'np.random.randn',
 309            'np.random.randint',
 310            'np.random.choice',
 311            'np.random.shuffle',
 312            'np.random.permutation',
 313            'np.random.normal',
 314            'np.random.uniform'
 315        ]
 316
 317        # Find all random operation calls
 318        random_op_calls = [
 319            call for call in node.nodes_of_class(nodes.Call)
 320            if any(op in getattr(call.func, 'as_string', lambda: '')() for op in random_operations)
 321        ]
 322
 323        # Find all random seed calls
 324        random_seed_calls = [
 325            call for call in node.nodes_of_class(nodes.Call)
 326            if 'np.random.seed' in getattr(call.func, 'as_string', lambda: '')()
 327        ]
 328
 329        # Only trigger smell if random operations are used without seed
 330        if random_op_calls and not random_seed_calls:
 331            self.add_smell('NumPy', 'Randomness Control Checker', node, file_path)
 332
 333    def detect_array_creation(self, node: nodes.Module, file_path: str):
 334        """Detect inefficient array creation patterns"""
 335        for call in node.nodes_of_class(nodes.Call):
 336            # Check for list to array conversion without dtype
 337            if ('np.array' in call.func.as_string() and
 338                    not any(kw.arg == 'dtype' for kw in call.keywords)):
 339                self.add_smell('NumPy', 'Array Creation Efficiency', call, file_path)
 340
 341            # Check for zeros/ones/empty without dtype
 342            if any(func in call.func.as_string() for func in ['np.zeros', 'np.ones', 'np.empty']) and \
 343               not any(kw.arg == 'dtype' for kw in call.keywords):
 344                self.add_smell('NumPy', 'Array Creation Efficiency', call, file_path)
 345
 346    def detect_inefficient_operations(self, node: nodes.Module, file_path: str):
 347        """Detect inefficient numerical operations"""
 348        for loop in node.nodes_of_class(nodes.For):
 349            # Check for element-wise operations in loops
 350            if any(op in loop.as_string() for op in ['np.sum', 'np.mean', 'np.max', 'np.min']):
 351                self.add_smell('NumPy', 'Inefficient Operations', loop, file_path)
 352
 353        # Check for inefficient concatenation
 354        for call in node.nodes_of_class(nodes.Call):
 355            if 'np.concatenate' in call.func.as_string():
 356                parent = call.parent
 357                while parent and not isinstance(parent, nodes.For):
 358                    parent = parent.parent
 359                if parent:  # If concatenate is inside a loop
 360                    self.add_smell('NumPy', 'Inefficient Operations', call, file_path)
 361
 362    def detect_dtype_consistency(self, node: nodes.Module, file_path: str):
 363        """Detect potential dtype inconsistency issues"""
 364        for binop in node.nodes_of_class(nodes.BinOp):
 365            # Check for mixed integer and float operations
 366            if (isinstance(binop.left, nodes.Call) and isinstance(binop.right, nodes.Call) and
 367                    'np.' in binop.left.func.as_string() and 'np.' in binop.right.func.as_string()):
 368                if ('int' in binop.left.func.as_string() and 'float' in binop.right.func.as_string()) or \
 369                   ('float' in binop.left.func.as_string() and 'int' in binop.right.func.as_string()):
 370                    self.add_smell('NumPy', 'Dtype Consistency', binop, file_path)
 371
 372    def detect_broadcasting_issues(self, node: nodes.Module, file_path: str):
 373        """Detect potential broadcasting issues"""
 374        for binop in node.nodes_of_class(nodes.BinOp):
 375            if isinstance(binop.left, nodes.Call) and isinstance(binop.right, nodes.Call):
 376                # Check for operations between arrays that might have broadcasting issues
 377                if ('reshape' in binop.left.as_string() or 'reshape' in binop.right.as_string() or
 378                        'transpose' in binop.left.as_string() or 'transpose' in binop.right.as_string()):
 379                    self.add_smell('NumPy', 'Broadcasting Risk', binop, file_path)
 380
 381    def detect_copy_view_issues(self, node: nodes.Module, file_path: str):
 382        """Detect potential copy/view confusion in NumPy array operations"""
 383        for assign in node.nodes_of_class(nodes.Assign):
 384            # Only check assignments involving array slicing
 385            if isinstance(assign.value, nodes.Subscript):
 386                # Skip if it's not a NumPy array operation
 387                if not any(np_indicator in assign.as_string()
 388                           for np_indicator in ['np.', 'numpy.']):
 389                    continue
 390
 391                # Check if the slice is being modified later
 392                is_modified = False
 393                parent = assign.parent
 394                while parent and not isinstance(parent, nodes.Module):
 395                    if isinstance(parent, (nodes.Assign, nodes.AugAssign)):
 396                        # Look for modifications to the assigned variable
 397                        target_name = assign.targets[0].as_string()
 398                        if target_name in parent.as_string():
 399                            is_modified = True
 400                            break
 401                    parent = parent.parent
 402
 403                # Only flag if:
 404                # 1. The slice is modified later
 405                # 2. No explicit copy is made
 406                # 3. Not using advanced indexing (which creates copies)
 407                if (is_modified and
 408                    not any(method in assign.as_string() for method in ['.copy()', 'np.copy']) and
 409                        not any(idx in assign.value.as_string() for idx in ['[[', 'bool', 'mask'])):
 410                    self.add_smell('NumPy', 'Copy-View Confusion', assign, file_path)
 411
 412    def detect_axis_specification(self, node: nodes.Module, file_path: str):
 413        """Detect missing axis specifications in array operations"""
 414        axis_operations = ['sum', 'mean', 'max', 'min', 'argmax', 'argmin', 'any', 'all']
 415        for call in node.nodes_of_class(nodes.Call):
 416            if (any(op in call.func.as_string() for op in axis_operations) and
 417                'np.' in call.func.as_string() and
 418                not any(kw.arg == 'axis' for kw in call.keywords) and
 419                    len(call.args) < 2):  # axis can also be specified as positional argument
 420                self.add_smell('NumPy', 'Missing Axis Specification', call, file_path)
 421
 422    # Scikit-learn Detection Methods
 423    def detect_scikitlearn_smells(self, node: nodes.Module, file_path: str):
 424        """Detect Scikit-learn specific code smells like:
 425        - Missing data scaling
 426        - Not using pipelines
 427        - Missing cross validation
 428        - Missing random state
 429        - Missing verbose mode
 430        - Using only threshold-dependent metrics
 431        - Missing unit tests
 432        - Data leakage risks
 433        - Missing exception handling
 434
 435        Args:
 436            node: AST node representing the Python module
 437            file_path: Path to the file being analyzed
 438        """
 439        self.detect_scaling_usage(node, file_path)
 440        self.detect_pipeline_usage(node, file_path)
 441        self.detect_cross_validation(node, file_path)
 442        self.detect_random_state(node, file_path)
 443        self.detect_verbose_mode(node, file_path)
 444        self.detect_threshold_metrics(node, file_path)
 445        self.detect_unit_tests(node, file_path)
 446        self.detect_data_leakage(node, file_path)
 447        self.detect_exception_handling(node, file_path)
 448
 449    def detect_scaling_usage(self, node: nodes.Module, file_path: str):
 450        scaling_sensitive_estimators = [
 451            'SVM', 'SVR', 'PCA', 'KMeans', 'NeuralNetwork', 'LogisticRegression'
 452        ]
 453        scaling_methods = ['StandardScaler', 'MinMaxScaler', 'RobustScaler']
 454
 455        # Check if scaling-sensitive estimators are used
 456        has_sensitive_estimator = any(
 457            estimator in call.func.as_string()
 458            for call in node.nodes_of_class(nodes.Call)
 459            for estimator in scaling_sensitive_estimators
 460        )
 461
 462        # Check if any scaling is applied
 463        has_scaling = any(
 464            method in call.func.as_string()
 465            for call in node.nodes_of_class(nodes.Call)
 466            for method in scaling_methods
 467        )
 468
 469        # Only report if using sensitive estimators without scaling
 470        if has_sensitive_estimator and not has_scaling:
 471            self.add_smell('ScikitLearn', 'Scaler Missing Checker', node, file_path)
 472
 473    def detect_pipeline_usage(self, node: nodes.Module, file_path: str):
 474        # Check if there are multiple preprocessing or model fitting steps
 475        preprocessing_steps = [
 476            'StandardScaler', 'MinMaxScaler', 'PCA', 'SelectKBest',
 477            'PolynomialFeatures', 'OneHotEncoder', 'LabelEncoder'
 478        ]
 479        model_steps = [
 480            'fit', 'predict', 'transform', 'fit_transform'
 481        ]
 482
 483        has_preprocessing = any(
 484            step in call.func.as_string()
 485            for call in node.nodes_of_class(nodes.Call)
 486            for step in preprocessing_steps
 487        )
 488        has_model_steps = any(
 489            step in call.func.as_string()
 490            for call in node.nodes_of_class(nodes.Call)
 491            for step in model_steps
 492        )
 493
 494        # Only suggest Pipeline if multiple steps are present
 495        if has_preprocessing and has_model_steps and not any(
 496            'Pipeline' in call.func.as_string()
 497            for call in node.nodes_of_class(nodes.Call)
 498        ):
 499            self.add_smell('ScikitLearn', 'Pipeline Checker', node, file_path)
 500
 501    def detect_cross_validation(self, node: nodes.Module, file_path: str):
 502        # Check if model training is performed
 503        training_indicators = ['fit', 'train']
 504        has_training = any(
 505            indicator in call.func.as_string()
 506            for call in node.nodes_of_class(nodes.Call)
 507            for indicator in training_indicators
 508        )
 509
 510        # Check for any cross-validation technique
 511        cv_methods = [
 512            'cross_val_score', 'KFold', 'StratifiedKFold',
 513            'cross_validate', 'GridSearchCV', 'RandomizedSearchCV'
 514        ]
 515        has_cv = any(
 516            method in call.func.as_string()
 517            for call in node.nodes_of_class(nodes.Call)
 518            for method in cv_methods
 519        )
 520
 521        # Only suggest cross-validation for model training scenarios
 522        if has_training and not has_cv:
 523            self.add_smell('ScikitLearn', 'Cross Validation Checker', node, file_path)
 524
 525    def detect_random_state(self, node: nodes.Module, file_path: str):
 526        # List of methods that accept random_state
 527        random_state_methods = [
 528            'train_test_split', 'KFold', 'RandomForest', 'KMeans',
 529            'PCA', 'shuffle', 'random_state'
 530        ]
 531
 532        # Check if any random-dependent operations are used
 533        random_dependent_calls = [
 534            call for call in node.nodes_of_class(nodes.Call)
 535            if any(method in call.func.as_string() for method in random_state_methods)
 536        ]
 537
 538        # Check if random_state is set for these calls
 539        for call in random_dependent_calls:
 540            if not any(
 541                kw.arg == 'random_state' for kw in call.keywords
 542            ):
 543                self.add_smell('ScikitLearn', 'Randomness Control Checker', call, file_path)
 544
 545    def detect_verbose_mode(self, node: nodes.Module, file_path: str):
 546        # Only check for verbose in time-consuming operations
 547        time_consuming_ops = [
 548            'GridSearchCV', 'RandomizedSearchCV', 'fit',
 549            'RandomForest', 'GradientBoosting'
 550        ]
 551
 552        for call in node.nodes_of_class(nodes.Call):
 553            if (any(op in call.func.as_string() for op in time_consuming_ops) and
 554                    not any(kw.arg == 'verbose' for kw in call.keywords)):
 555                self.add_smell('ScikitLearn', 'Verbose Mode Checker', call, file_path)
 556
 557    def detect_threshold_metrics(self, node: nodes.Module, file_path: str):
 558        # Only check for classification-related tasks
 559        classification_indicators = [
 560            'classifier', 'predict_proba', 'accuracy_score',
 561            'precision_score', 'recall_score', 'f1_score',
 562            'classification_report'
 563        ]
 564
 565        # Check if it's a classification task
 566        is_classification = any(
 567            indicator in node.as_string()
 568            for indicator in classification_indicators
 569        )
 570
 571        # Check for threshold-independent metrics
 572        threshold_independent_metrics = [
 573            'roc_auc_score', 'average_precision_score',
 574            'precision_recall_curve', 'roc_curve'
 575        ]
 576        has_threshold_metrics = any(
 577            metric in call.func.as_string()
 578            for call in node.nodes_of_class(nodes.Call)
 579            for metric in threshold_independent_metrics
 580        )
 581
 582        # Only suggest threshold-independent metrics for classification tasks
 583        if is_classification and not has_threshold_metrics:
 584            self.add_smell('ScikitLearn', 'Dependent Threshold Checker', node, file_path)
 585
 586    def detect_unit_tests(self, node: nodes.Module, file_path: str):
 587        # Check if this is a source file (not already a test file)
 588        is_test_file = ('test' in file_path.lower() or
 589                        node.as_string().lower().startswith('test'))
 590
 591        # Check for model training or evaluation code
 592        ml_operations = [
 593            'fit', 'predict', 'transform', 'score',
 594            'cross_val_score', 'GridSearchCV'
 595        ]
 596        has_ml_operations = any(
 597            op in call.func.as_string()
 598            for call in node.nodes_of_class(nodes.Call)
 599            for op in ml_operations
 600        )
 601
 602        # Check for testing frameworks
 603        test_frameworks = [
 604            'unittest', 'pytest', 'nose',
 605            'TestCase', '@test', '@pytest'
 606        ]
 607        has_tests = any(
 608            framework in node.as_string()
 609            for framework in test_frameworks
 610        )
 611
 612        # Only suggest adding tests for non-test ML files without tests
 613        if not is_test_file and has_ml_operations and not has_tests:
 614            self.add_smell('ScikitLearn', 'Unit Testing Checker', node, file_path)
 615
 616    def detect_data_leakage(self, node: nodes.Module, file_path: str):
 617        # Check for data preprocessing operations
 618        preprocessing_ops = [
 619            'fit_transform', 'transform', 'StandardScaler',
 620            'MinMaxScaler', 'PCA', 'feature_selection'
 621        ]
 622        has_preprocessing = any(
 623            op in call.func.as_string()
 624            for call in node.nodes_of_class(nodes.Call)
 625            for op in preprocessing_ops
 626        )
 627
 628        # Check for model training
 629        has_model_training = any(
 630            'fit' in call.func.as_string()
 631            for call in node.nodes_of_class(nodes.Call)
 632        )
 633
 634        # Check for proper train-test splitting
 635        splitting_methods = [
 636            'train_test_split', 'KFold', 'StratifiedKFold',
 637            'GroupKFold', 'TimeSeriesSplit'
 638        ]
 639        has_proper_split = any(
 640            method in call.func.as_string()
 641            for call in node.nodes_of_class(nodes.Call)
 642            for method in splitting_methods
 643        )
 644
 645        # Only report if preprocessing and training without proper splitting
 646        if (has_preprocessing and has_model_training and not has_proper_split):
 647            self.add_smell('ScikitLearn', 'Data Leakage Checker', node, file_path)
 648
 649    def detect_exception_handling(self, node: nodes.Module, file_path: str):
 650        # Check for risky operations that should have exception handling
 651        risky_operations = [
 652            'fit', 'predict', 'transform', 'inverse_transform',
 653            'cross_val_score', 'GridSearchCV', 'load_', 'dump_',
 654            'pickle', 'joblib'
 655        ]
 656
 657        # Find all risky operation calls
 658        risky_calls = [
 659            call for call in node.nodes_of_class(nodes.Call)
 660            if any(op in call.func.as_string() for op in risky_operations)
 661        ]
 662
 663        # Check if these calls are within try-except blocks
 664        for call in risky_calls:
 665            parent = call.parent
 666            within_try_block = False
 667
 668            while parent and not isinstance(parent, nodes.Module):
 669                if isinstance(parent, nodes.Try):
 670                    within_try_block = True
 671                    break
 672                parent = parent.parent
 673
 674            # Only report if risky operations are not in try-except blocks
 675            if risky_calls and not within_try_block:
 676                self.add_smell('ScikitLearn', 'Exception Handling Checker', call, file_path)
 677
 678    def check_imports(self, node: nodes.Module, file_path: str):
 679        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 680            if isinstance(import_node, nodes.Import):
 681                for name, alias in import_node.names:
 682                    if name == 'numpy' and alias != 'np':
 683                        self.add_smell('General', 'Import Checker', import_node, file_path)
 684                    elif name == 'pandas' and alias != 'pd':
 685                        self.add_smell('General', 'Import Checker', import_node, file_path)
 686            elif isinstance(import_node, nodes.ImportFrom):
 687                if import_node.modname == 'numpy' and any(alias != 'np' for _, alias in import_node.names):
 688                    self.add_smell('General', 'Import Checker', import_node, file_path)
 689                elif import_node.modname == 'pandas' and any(alias != 'pd' for _, alias in import_node.names):
 690                    self.add_smell('General', 'Import Checker', import_node, file_path)
 691
 692    # PyTorch Detection Methods
 693    def detect_pytorch_smells(self, node: nodes.Module, file_path: str):
 694        """Detect PyTorch-specific code smells like:
 695        - Missing random seeds
 696        - Not using deterministic mode
 697        - DataLoader randomness issues
 698        - Missing masks for numerical operations
 699        - Direct forward() calls
 700        - Missing gradient zeroing
 701        - Missing batch normalization
 702        - Missing dropout
 703        - Missing data augmentation
 704        - Missing learning rate scheduling
 705        - Missing logging
 706        - Missing model evaluation mode
 707
 708        Args:
 709            node: AST node representing the Python module
 710            file_path: Path to the file being analyzed
 711        """
 712        self.detect_pytorch_random_seed(node, file_path)
 713        self.detect_pytorch_deterministic(node, file_path)
 714        self.detect_pytorch_dataloader_random(node, file_path)
 715        self.detect_pytorch_mask(node, file_path)
 716        self.detect_pytorch_forward(node, file_path)
 717        self.detect_pytorch_grad_zero(node, file_path)
 718        self.detect_pytorch_batch_norm(node, file_path)
 719        self.detect_pytorch_dropout(node, file_path)
 720        self.detect_pytorch_augmentation(node, file_path)
 721        self.detect_pytorch_lr_scheduler(node, file_path)
 722        self.detect_pytorch_logging(node, file_path)
 723        self.detect_pytorch_eval_mode(node, file_path)
 724
 725    def detect_pytorch_random_seed(self, node: nodes.Module, file_path: str):
 726        # Check if there are any random operations that need seeding
 727        random_operations = [
 728            'torch.rand', 'torch.randn', 'torch.randint',
 729            'torch.randperm', 'torch.bernoulli', 'torch.normal',
 730            'torch.dropout', 'torch.nn.Dropout'
 731        ]
 732
 733        has_random_ops = any(
 734            op in call.func.as_string()
 735            for call in node.nodes_of_class(nodes.Call)
 736            for op in random_operations
 737        )
 738
 739        has_manual_seed = any(
 740            'torch.manual_seed' in call.func.as_string() or
 741            'torch.cuda.manual_seed' in call.func.as_string() or
 742            'torch.cuda.manual_seed_all' in call.func.as_string()
 743            for call in node.nodes_of_class(nodes.Call)
 744        )
 745
 746        # Only report if random operations are used without seeding
 747        if has_random_ops and not has_manual_seed:
 748            self.add_smell('PyTorch', 'Randomness Control Checker', node, file_path)
 749
 750    def detect_pytorch_deterministic(self, node: nodes.Module, file_path: str):
 751        # Check for operations that benefit from deterministic algorithms
 752        deterministic_sensitive_ops = [
 753            'torch.nn.Conv', 'torch.nn.LSTM', 'torch.nn.GRU',
 754            'torch.backends.cudnn', 'torch.cuda', 'DataLoader'
 755        ]
 756
 757        has_sensitive_ops = any(
 758            op in node.as_string()
 759            for op in deterministic_sensitive_ops
 760        )
 761
 762        has_deterministic_setting = any(
 763            'torch.use_deterministic_algorithms' in call.func.as_string() or
 764            'torch.backends.cudnn.deterministic' in call.as_string()
 765            for call in node.nodes_of_class(nodes.Call)
 766        )
 767
 768        # Only report if sensitive operations are used without deterministic setting
 769        if has_sensitive_ops and not has_deterministic_setting:
 770            self.add_smell('PyTorch', 'Deterministic Algorithm Usage Checker', node, file_path)
 771
 772    def detect_pytorch_dataloader_random(self, node: nodes.Module, file_path: str):
 773        for call in node.nodes_of_class(nodes.Call):
 774            if 'DataLoader' in call.func.as_string():
 775                # Check if it's actually a PyTorch DataLoader
 776                is_pytorch_dataloader = any(
 777                    'torch' in imp.modname
 778                    for imp in node.nodes_of_class(nodes.ImportFrom)
 779                )
 780
 781                # Check if shuffling is enabled
 782                has_shuffle = any(
 783                    (kw.arg == 'shuffle' and getattr(kw.value, 'value', False) is True)
 784                    for kw in call.keywords
 785                )
 786
 787                # Check for proper random state control
 788                has_random_control = any(
 789                    (kw.arg in ['worker_init_fn', 'generator'])
 790                    for kw in call.keywords
 791                )
 792
 793                # Only report if it's a PyTorch DataLoader with shuffling but no random control
 794                if is_pytorch_dataloader and has_shuffle and not has_random_control:
 795                    self.add_smell('PyTorch', 'Randomness Control Checker (PyTorch-Dataloader)', call, file_path)
 796
 797    def detect_pytorch_mask(self, node: nodes.Module, file_path: str):
 798        for call in node.nodes_of_class(nodes.Call):
 799            if 'torch.log' in call.func.as_string():
 800                # Check if the input might contain zeros or negative values
 801                input_arg = call.args[0] if call.args else None
 802                if input_arg:
 803                    # Look for potential risky operations before the log
 804                    risky_ops = ['zeros', 'randn', 'rand', 'sub', 'subtract']
 805                    has_risky_input = any(
 806                        op in input_arg.as_string()
 807                        for op in risky_ops
 808                    )
 809
 810                    # Check if masking is applied
 811                    has_mask = (
 812                        len(call.args) > 1 or
 813                        any('clamp' in node.as_string() or
 814                            'mask' in node.as_string() or
 815                            'where' in node.as_string())
 816                    )
 817
 818                    # Only report if there's a risk of invalid input without masking
 819                    if has_risky_input and not has_mask:
 820                        self.add_smell('PyTorch', 'Mask Missing Checker', call, file_path)
 821
 822    def detect_pytorch_forward(self, node: nodes.Module, file_path: str):
 823        for call in node.nodes_of_class(nodes.Call):
 824            if 'forward' in call.func.as_string():
 825                # Check if it's a direct forward call on a neural network
 826                is_nn_forward = (
 827                    isinstance(call.func, nodes.Attribute) and
 828                    'forward' in call.func.attrname and
 829                    any(parent_cls.bases and
 830                        'nn.Module' in parent_cls.bases[0].as_string()
 831                        for parent_cls in node.nodes_of_class(nodes.ClassDef))
 832                )
 833
 834                # Check if it's called directly instead of using __call__
 835                is_direct_call = (
 836                    'net.forward' in call.func.as_string() or
 837                    'model.forward' in call.func.as_string()
 838                )
 839
 840                # Only report if it's a direct forward call on an nn.Module
 841                if is_nn_forward and is_direct_call:
 842                    self.add_smell('PyTorch', 'Net Forward Checker', call, file_path)
 843
 844    def detect_pytorch_grad_zero(self, node: nodes.Module, file_path: str):
 845        # Check if there's actual training happening
 846        has_training_loop = any(
 847            'loss.backward' in call.func.as_string() or
 848            'backward' in call.func.as_string()
 849            for call in node.nodes_of_class(nodes.Call)
 850        )
 851
 852        has_optimizer = any(
 853            'optim.' in call.func.as_string()
 854            for call in node.nodes_of_class(nodes.Call)
 855        )
 856
 857        has_grad_zero = any(
 858            'zero_grad' in call.func.as_string() or
 859            'set_to_none' in call.func.as_string()  # Alternative method
 860            for call in node.nodes_of_class(nodes.Call)
 861        )
 862
 863        # Only report if there's training without gradient zeroing
 864        if has_training_loop and has_optimizer and not has_grad_zero:
 865            self.add_smell('PyTorch', 'Gradient Clear Checker', node, file_path)
 866
 867    def detect_pytorch_batch_norm(self, node: nodes.Module, file_path: str):
 868        # Check if it's a CNN or deep network that could benefit from BatchNorm
 869        conv_layers = [
 870            'Conv1d', 'Conv2d', 'Conv3d',
 871            'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'
 872        ]
 873
 874        has_conv_layers = any(
 875            layer in call.func.as_string()
 876            for call in node.nodes_of_class(nodes.Call)
 877            for layer in conv_layers
 878        )
 879
 880        has_deep_structure = len([
 881            call for call in node.nodes_of_class(nodes.Call)
 882            if any(layer in call.func.as_string() for layer in ['Linear', 'Conv'])
 883        ]) > 2
 884
 885        has_batch_norm = any(
 886            'BatchNorm' in call.func.as_string()
 887            for call in node.nodes_of_class(nodes.Call)
 888        )
 889
 890        # Only suggest BatchNorm for appropriate architectures
 891        if (has_conv_layers or has_deep_structure) and not has_batch_norm:
 892            self.add_smell('PyTorch', 'Batch Normalisation Checker', node, file_path)
 893
 894    def detect_pytorch_dropout(self, node: nodes.Module, file_path: str):
 895        # Check if the model is complex enough to benefit from dropout
 896        has_multiple_layers = len([
 897            call for call in node.nodes_of_class(nodes.Call)
 898            if any(layer in call.func.as_string()
 899                   for layer in ['Linear', 'Conv', 'LSTM', 'GRU'])
 900        ]) > 2
 901
 902        has_training_code = any(
 903            'train' in call.func.as_string() or
 904            'backward' in call.func.as_string()
 905            for call in node.nodes_of_class(nodes.Call)
 906        )
 907
 908        has_dropout = any(
 909            'Dropout' in call.func.as_string()
 910            for call in node.nodes_of_class(nodes.Call)
 911        )
 912
 913        # Only suggest dropout for complex models during training
 914        if has_multiple_layers and has_training_code and not has_dropout:
 915            self.add_smell('PyTorch', 'Dropout Usage Checker', node, file_path)
 916
 917    def detect_pytorch_augmentation(self, node: nodes.Module, file_path: str):
 918        # Check if it's a computer vision task
 919        vision_indicators = [
 920            'ImageFolder', 'Dataset', 'DataLoader',
 921            'Conv2d', 'Conv3d', 'ResNet', 'VGG',
 922            'image', 'img', 'PIL'
 923        ]
 924
 925        is_vision_task = any(
 926            indicator in node.as_string()
 927            for indicator in vision_indicators
 928        )
 929
 930        has_training = any(
 931            'train' in call.func.as_string() or
 932            'fit' in call.func.as_string()
 933            for call in node.nodes_of_class(nodes.Call)
 934        )
 935
 936        has_augmentation = any(
 937            'transforms' in call.func.as_string() or
 938            'augment' in call.func.as_string().lower()
 939            for call in node.nodes_of_class(nodes.Call)
 940        )
 941
 942        # Only suggest augmentation for vision tasks during training
 943        if is_vision_task and has_training and not has_augmentation:
 944            self.add_smell('PyTorch', 'Data Augmentation Checker', node, file_path)
 945
 946    def detect_pytorch_lr_scheduler(self, node: nodes.Module, file_path: str):
 947        # Check if there's a training loop with enough epochs
 948        has_epochs = False
 949        for node in node.nodes_of_class(nodes.For):
 950            if 'epoch' in node.as_string().lower():
 951                try:
 952                    # Try to determine the number of epochs
 953                    if hasattr(node.iter, 'args') and len(node.iter.args) > 0:
 954                        epoch_num = int(node.iter.args[0].value)
 955                        has_epochs = epoch_num > 10  # Suggest scheduler for longer training
 956                except (AttributeError, ValueError):
 957                    continue
 958
 959        has_optimizer = any(
 960            'optim.' in call.func.as_string()
 961            for call in node.nodes_of_class(nodes.Call)
 962        )
 963
 964        has_scheduler = any(
 965            'lr_scheduler' in call.func.as_string() or
 966            'LRScheduler' in call.func.as_string()
 967            for call in node.nodes_of_class(nodes.Call)
 968        )
 969
 970        # Only suggest scheduler for long training processes
 971        if has_epochs and has_optimizer and not has_scheduler:
 972            self.add_smell('PyTorch', 'Learning Rate Scheduler Checker', node, file_path)
 973
 974    def detect_pytorch_logging(self, node: nodes.Module, file_path: str):
 975        # Check if there's actual training to log
 976        has_training_loop = any(
 977            'train' in call.func.as_string() or
 978            'backward' in call.func.as_string()
 979            for call in node.nodes_of_class(nodes.Call)
 980        )
 981
 982        has_metrics = any(
 983            metric in node.as_string().lower()
 984            for metric in ['loss', 'accuracy', 'score', 'metric']
 985        )
 986
 987        has_logging = any(
 988            logger in node.as_string()
 989            for logger in [
 990                'tensorboard', 'SummaryWriter', 'wandb',
 991                'MLflow', 'Neptune', 'logger'
 992            ]
 993        )
 994
 995        # Only suggest logging for training with metrics
 996        if has_training_loop and has_metrics and not has_logging:
 997            self.add_smell('PyTorch', 'Logging Checker', node, file_path)
 998
 999    def detect_pytorch_eval_mode(self, node: nodes.Module, file_path: str):
1000        # Check if there's validation/testing happening
1001        evaluation_indicators = [
1002            'val_loader', 'test_loader', 'validate',
1003            'evaluation', 'testing', 'predict'
1004        ]
1005
1006        has_evaluation = any(
1007            indicator in node.as_string().lower()
1008            for indicator in evaluation_indicators
1009        )
1010
1011        has_model_usage = any(
1012            'forward' in call.func.as_string() or
1013            'model(' in call.as_string()
1014            for call in node.nodes_of_class(nodes.Call)
1015        )
1016
1017        has_eval_mode = any(
1018            'eval' in call.func.as_string() or
1019            'train(False)' in call.func.as_string()
1020            for call in node.nodes_of_class(nodes.Call)
1021        )
1022
1023        # Only suggest eval mode when doing validation/testing
1024        if has_evaluation and has_model_usage and not has_eval_mode:
1025            self.add_smell('PyTorch', 'Model Evaluation Checker', node, file_path)
1026
1027    # TensorFlow Detection Methods
1028    def detect_tensorflow_smells(self, node: nodes.Module, file_path: str):
1029        """Detect TensorFlow-specific code smells like:
1030        - Missing random seeds
1031        - Missing early stopping
1032        - Missing checkpointing
1033        - Memory leaks
1034        - Missing masks
1035        - Python lists instead of TensorArrays
1036        - Missing threshold-independent metrics
1037        - Missing logging
1038        - Missing batch normalization
1039        - Missing dropout
1040        - Missing data augmentation
1041        - Missing learning rate scheduling
1042        - Missing model evaluation
1043
1044        Args:
1045            node: AST node representing the Python module
1046            file_path: Path to the file being analyzed
1047        """
1048        self.detect_tf_random_seed(node, file_path)
1049        self.detect_tf_early_stopping(node, file_path)
1050        self.detect_tf_checkpointing(node, file_path)
1051        self.detect_tf_memory_release(node, file_path)
1052        self.detect_tf_mask(node, file_path)
1053        self.detect_tf_tensor_array(node, file_path)
1054        self.detect_tf_metrics(node, file_path)
1055        self.detect_tf_logging(node, file_path)
1056        self.detect_tf_batch_norm(node, file_path)
1057        self.detect_tf_dropout(node, file_path)
1058        self.detect_tf_augmentation(node, file_path)
1059        self.detect_tf_lr_scheduler(node, file_path)
1060        self.detect_tf_model_evaluation(node, file_path)
1061
1062    def detect_tf_random_seed(self, node: nodes.Module, file_path: str):
1063        # Check for operations that need random seed control
1064        random_ops = [
1065            'random.normal', 'random.uniform', 'random.shuffle',
1066            'dropout', 'RandomRotation', 'RandomFlip', 'RandomZoom'
1067        ]
1068
1069        has_random_ops = any(
1070            op in call.func.as_string()
1071            for call in node.nodes_of_class(nodes.Call)
1072            for op in random_ops
1073        )
1074
1075        has_seed_set = any(
1076            'random.set_seed' in call.func.as_string() or
1077            'set_random_seed' in call.func.as_string()
1078            for call in node.nodes_of_class(nodes.Call)
1079        )
1080
1081        # Only report if random operations are used without seed
1082        if has_random_ops and not has_seed_set:
1083            self.add_smell('TensorFlow', 'Randomness Control Checker', node, file_path)
1084
1085    def detect_tf_early_stopping(self, node: nodes.Module, file_path: str):
1086        # Check if there's actual model training
1087        has_training = any(
1088            'model.fit' in call.func.as_string() or
1089            'fit(' in call.func.as_string()
1090            for call in node.nodes_of_class(nodes.Call)
1091        )
1092
1093        # Check for training loop with multiple epochs
1094        has_multiple_epochs = any(
1095            'epochs' in kw.arg and getattr(kw.value, 'value', 1) > 1
1096            for call in node.nodes_of_class(nodes.Call)
1097            for kw in call.keywords
1098        )
1099
1100        has_early_stopping = any(
1101            'EarlyStopping' in call.func.as_string()
1102            for call in node.nodes_of_class(nodes.Call)
1103        )
1104
1105        # Only suggest early stopping for actual training with multiple epochs
1106        if has_training and has_multiple_epochs and not has_early_stopping:
1107            self.add_smell('TensorFlow', 'Early Stopping Checker', node, file_path)
1108
1109    def detect_tf_checkpointing(self, node: nodes.Module, file_path: str):
1110        # Check for model training and complexity
1111        has_training = any(
1112            'model.fit' in call.func.as_string()
1113            for call in node.nodes_of_class(nodes.Call)
1114        )
1115
1116        has_complex_model = any(
1117            layer in node.as_string()
1118            for layer in ['Dense', 'Conv', 'LSTM', 'GRU']
1119        )
1120
1121        has_checkpointing = any(
1122            'ModelCheckpoint' in call.func.as_string() or
1123            'save_weights' in call.func.as_string()
1124            for call in node.nodes_of_class(nodes.Call)
1125        )
1126
1127        # Only suggest checkpointing for complex models during training
1128        if has_training and has_complex_model and not has_checkpointing:
1129            self.add_smell('TensorFlow', 'Checkpointing Checker', node, file_path)
1130
1131    def detect_tf_memory_release(self, node: nodes.Module, file_path: str):
1132        # Check for memory-intensive operations
1133        memory_intensive_ops = [
1134            'model.fit', 'predict', 'evaluate',
1135            'Conv', 'LSTM', 'GRU', 'Attention'
1136        ]
1137
1138        has_intensive_ops = any(
1139            op in call.func.as_string()
1140            for call in node.nodes_of_class(nodes.Call)
1141            for op in memory_intensive_ops
1142        )
1143
1144        has_memory_release = any(
1145            'clear_session' in call.func.as_string() or
1146            'reset_states' in call.func.as_string()
1147            for call in node.nodes_of_class(nodes.Call)
1148        )
1149
1150        # Only suggest memory release for memory-intensive operations
1151        if has_intensive_ops and not has_memory_release:
1152            self.add_smell('TensorFlow', 'Memory Release Checker', node, file_path)
1153
1154    def detect_tf_mask(self, node: nodes.Module, file_path: str):
1155        for call in node.nodes_of_class(nodes.Call):
1156            if hasattr(call.func, 'as_string') and 'tf.math.log' in call.func.as_string():
1157                # Check for potential zero/negative inputs
1158                input_arg = call.args[0] if call.args else None
1159                if input_arg and hasattr(input_arg, 'as_string'):  # Add null check
1160                    risky_ops = ['zeros', 'random', 'subtract', 'sub']
1161                    has_risky_input = any(
1162                        op in input_arg.as_string()
1163                        for op in risky_ops
1164                    )
1165
1166                    has_mask = (
1167                        len(call.args) > 1 or
1168                        any(mask_op in node.as_string()
1169                            for mask_op in ['where', 'clip', 'maximum'])
1170                    )
1171
1172                    # Only report if there's risk of invalid input without masking
1173                    if has_risky_input and not has_mask:
1174                        self.add_smell('TensorFlow', 'Mask Missing Checker', call, file_path)
1175
1176    def detect_tf_tensor_array(self, node: nodes.Module, file_path: str):
1177        # Check for dynamic sequence operations
1178        dynamic_ops = [
1179            'RNN', 'LSTM', 'GRU', 'while_loop',
1180            'map_fn', 'scan', 'dynamic'
1181        ]
1182
1183        has_dynamic_ops = any(
1184            op in node.as_string()
1185            for op in dynamic_ops
1186        )
1187
1188        using_python_list = any(
1189            'append' in call.func.as_string() or
1190            'extend' in call.func.as_string()
1191            for call in node.nodes_of_class(nodes.Call)
1192        )
1193
1194        has_tensor_array = any(
1195            'TensorArray' in call.func.as_string()
1196            for call in node.nodes_of_class(nodes.Call)
1197        )
1198
1199        # Only suggest TensorArray for dynamic operations using Python lists
1200        if has_dynamic_ops and using_python_list and not has_tensor_array:
1201            self.add_smell('TensorFlow', 'Tensor Array Checker', node, file_path)
1202
1203    def detect_tf_metrics(self, node: nodes.Module, file_path: str):
1204        # Check if it's a classification task
1205        classification_indicators = [
1206            'Binary', 'Categorical', 'sparse_categorical',
1207            'accuracy', 'precision', 'recall', 'f1'
1208        ]
1209
1210        is_classification = any(
1211            indicator in node.as_string()
1212            for indicator in classification_indicators
1213        )
1214
1215        has_basic_metrics = any(
1216            metric in call.func.as_string()
1217            for call in node.nodes_of_class(nodes.Call)
1218            for metric in ['accuracy', 'precision', 'recall']
1219        )
1220
1221        has_threshold_independent = any(
1222            metric in call.func.as_string()
1223            for call in node.nodes_of_class(nodes.Call)
1224            for metric in ['AUC', 'AUROCScore', 'PrecisionRecallCurve']
1225        )
1226
1227        # Only suggest threshold-independent metrics for classification tasks
1228        if is_classification and has_basic_metrics and not has_threshold_independent:
1229            self.add_smell('TensorFlow', 'Dependent Threshold Checker', node, file_path)
1230
1231    def detect_tf_logging(self, node: nodes.Module, file_path: str):
1232        # Check if there's actual training to log
1233        has_training_loop = any(
1234            'model.fit' in call.func.as_string() or
1235            'train' in call.func.as_string()
1236            for call in node.nodes_of_class(nodes.Call)
1237        )
1238
1239        has_metrics = any(
1240            metric in node.as_string().lower()
1241            for metric in ['loss', 'accuracy', 'score', 'metric']
1242        )
1243
1244        has_logging = any(
1245            logger in node.as_string()
1246            for logger in [
1247                'tf.summary', 'TensorBoard', 'CSVLogger',
1248                'WandbCallback', 'MLflow'
1249            ]
1250        )
1251
1252        # Only suggest logging for training with metrics
1253        if has_training_loop and has_metrics and not has_logging:
1254            self.add_smell('TensorFlow', 'Logging Checker', node, file_path)
1255
1256    def detect_tf_batch_norm(self, node: nodes.Module, file_path: str):
1257        # Check if it's a deep network that could benefit from BatchNorm
1258        conv_layers = [
1259            'Conv1D', 'Conv2D', 'Conv3D',
1260            'Dense', 'SeparableConv'
1261        ]
1262
1263        has_conv_layers = any(
1264            layer in call.func.as_string()
1265            for call in node.nodes_of_class(nodes.Call)
1266            for layer in conv_layers
1267        )
1268
1269        # Check if model is deep enough
1270        layer_count = len([
1271            call for call in node.nodes_of_class(nodes.Call)
1272            if any(layer in call.func.as_string()
1273                   for layer in conv_layers)
1274        ])
1275
1276        has_batch_norm = any(
1277            'BatchNormalization' in call.func.as_string()
1278            for call in node.nodes_of_class(nodes.Call)
1279        )
1280
1281        # Only suggest BatchNorm for deep networks with conv layers
1282        if has_conv_layers and layer_count > 2 and not has_batch_norm:
1283            self.add_smell('TensorFlow', 'Batch Normalisation Checker', node, file_path)
1284
1285    def detect_tf_dropout(self, node: nodes.Module, file_path: str):
1286        # Check if model is complex enough to need dropout
1287        deep_layers = ['Dense', 'Conv', 'LSTM', 'GRU']
1288
1289        layer_count = len([
1290            call for call in node.nodes_of_class(nodes.Call)
1291            if any(layer in call.func.as_string()
1292                   for layer in deep_layers)
1293        ])
1294
1295        has_training = any(
1296            'model.fit' in call.func.as_string() or
1297            'training=True' in call.as_string()
1298            for call in node.nodes_of_class(nodes.Call)
1299        )
1300
1301        has_dropout = any(
1302            'Dropout' in call.func.as_string()
1303            for call in node.nodes_of_class(nodes.Call)
1304        )
1305
1306        # Only suggest dropout for complex models during training
1307        if layer_count > 2 and has_training and not has_dropout:
1308            self.add_smell('TensorFlow', 'Dropout Usage Checker', node, file_path)
1309
1310    def detect_tf_augmentation(self, node: nodes.Module, file_path: str):
1311        # Check if it's a computer vision task
1312        vision_indicators = [
1313            'image', 'img', 'Conv2D', 'Conv3D',
1314            'ImageDataGenerator', 'load_img'
1315        ]
1316
1317        is_vision_task = any(
1318            indicator in node.as_string()
1319            for indicator in vision_indicators
1320        )
1321
1322        has_training = any(
1323            'model.fit' in call.func.as_string() or
1324            'train' in call.func.as_string().lower()
1325            for call in node.nodes_of_class(nodes.Call)
1326        )
1327
1328        has_augmentation = any(
1329            aug in call.func.as_string()
1330            for call in node.nodes_of_class(nodes.Call)
1331            for aug in ['ImageDataGenerator', 'RandomFlip', 'RandomRotation', 'RandomZoom']
1332        )
1333
1334        # Only suggest augmentation for vision tasks during training
1335        if is_vision_task and has_training and not has_augmentation:
1336            self.add_smell('TensorFlow', 'Data Augmentation Checker', node, file_path)
1337
1338    def detect_tf_lr_scheduler(self, node: nodes.Module, file_path: str):
1339        # Check if there's a long training process
1340        has_epochs = False
1341        for node in node.nodes_of_class(nodes.For):
1342            if 'epoch' in node.as_string().lower():
1343                try:
1344                    if hasattr(node.iter, 'args') and len(node.iter.args) > 0:
1345                        epoch_num = int(node.iter.args[0].value)
1346                        has_epochs = epoch_num > 5  # Suggest scheduler for longer training
1347                except (AttributeError, ValueError):
1348                    continue
1349
1350        has_optimizer = any(
1351            'optimizer' in call.func.as_string().lower()
1352            for call in node.nodes_of_class(nodes.Call)
1353        )
1354
1355        has_scheduler = any(
1356            'LearningRateScheduler' in call.func.as_string() or
1357            'schedules' in call.func.as_string() or
1358            'ReduceLROnPlateau' in call.func.as_string()
1359            for call in node.nodes_of_class(nodes.Call)
1360        )
1361
1362        # Only suggest scheduler for long training processes
1363        if has_epochs and has_optimizer and not has_scheduler:
1364            self.add_smell('TensorFlow', 'Learning Rate Scheduler Checker', node, file_path)
1365
1366    def detect_tf_model_evaluation(self, node: nodes.Module, file_path: str):
1367        # Check if there's a model to evaluate
1368        has_model = any(
1369            'model' in call.func.as_string() or
1370            'Sequential' in call.func.as_string() or
1371            'Model(' in call.func.as_string()
1372            for call in node.nodes_of_class(nodes.Call)
1373        )
1374
1375        has_test_data = any(
1376            data in node.as_string().lower()
1377            for data in ['test', 'val', 'valid', 'evaluation']
1378        )
1379
1380        has_evaluation = any(
1381            'evaluate' in call.func.as_string() or
1382            'predict' in call.func.as_string()
1383            for call in node.nodes_of_class(nodes.Call)
1384        )
1385
1386        # Only suggest evaluation when there's a model and test data
1387        if has_model and has_test_data and not has_evaluation:
1388            self.add_smell('TensorFlow', 'Model Evaluation Checker', node, file_path)
1389
1390    def generate_report(self) -> str:
1391        """Generate a detailed text report of all detected code smells.
1392
1393        Returns:
1394            A formatted string containing the analysis report with:
1395            - Framework-specific smell counts
1396            - Details for each smell including location and fixes
1397            - Total number of smells detected
1398        """
1399        report = "Framework-Specific Code Smell Report\n====================================\n\n"
1400        smell_counts = {}
1401        for smell in self.smells:
1402            framework = smell['framework']
1403            if framework not in smell_counts:
1404                smell_counts[framework] = {}
1405            if smell['name'] not in smell_counts[framework]:
1406                smell_counts[framework][smell['name']] = 0
1407            smell_counts[framework][smell['name']] += 1
1408
1409            report += f"Framework: {framework}\n"
1410            report += f"Smell: {smell['name']}\n"
1411            report += f"File: {smell['file_path']}\n"
1412
1413            if smell['line_number'] != 0:
1414                report += f"Line: {smell['line_number']}\n"
1415
1416            code_lines = smell['code_snippet'].strip().split('\n')
1417            if len(code_lines) <= 3:
1418                report += f"Code Snippet:\n{smell['code_snippet']}\n"
1419
1420            report += f"How to Fix: {smell['how_to_fix']}\n"
1421            report += f"Benefits: {smell['benefits']}\n"
1422            report += f"Strategies: {smell['strategies']}\n\n"
1423
1424        report += "Smell Counts:\n"
1425        for framework, counts in smell_counts.items():
1426            report += f"{framework}:\n"
1427            for smell, count in counts.items():
1428                report += f"  {smell}: {count}\n"
1429        report += f"\nTotal smells detected: {len(self.smells)}"
1430        return report
1431
1432    def get_results(self) -> List[Dict[str, str]]:
1433        """Get a simplified list of detected code smells.
1434
1435        Returns:
1436            List of dictionaries containing:
1437            - framework: The ML framework
1438            - name: Name of the smell
1439            - fix: How to fix it
1440            - benefits: Benefits of fixing
1441            - location: Where it was found
1442        """
1443        return [
1444            {
1445                'framework': smell['framework'],
1446                'name': smell['name'],
1447                'fix': smell['how_to_fix'],
1448                'benefits': smell['benefits'],
1449                'location': f"Line {smell['line_number']}" if smell['line_number'] != 0 else ""
1450            }
1451            for smell in self.smells
1452        ]
1453
1454    def get_smells(self) -> Dict[str, List[Dict[str, str]]]:
1455        return {"General": [{"name": "Import Checker",
1456                             "how_to_fix": "Use standard naming conventions for imported modules.",
1457                             "benefits": "Improves code readability and maintainability.",
1458                             "strategies": "Follow standard naming conventions (e.g., import numpy as np, import pandas as pd)."}],
1459                "Pandas": [{"name": "Unnecessary Iteration",
1460                            "how_to_fix": "Use vectorized operations instead of loops.",
1461                            "benefits": "Enhances performance and reduces execution time.",
1462                            "strategies": "Replace loops with Pandas vectorized functions (e.g., apply, map, vectorized arithmetic operations)."},
1463                           {"name": "DataFrame Iteration Modification",
1464                            "how_to_fix": "Avoid modifying DataFrame during iteration.",
1465                            "benefits": "Prevents unexpected behaviour and potential data corruption.",
1466                            "strategies": "Use temporary variables or vectorized operations for modifications."},
1467                           {"name": "Chain Indexing",
1468                            "how_to_fix": "Use single indexing or .loc[], .iloc[] methods.",
1469                            "benefits": "Enhances code readability and prevents performance issues.",
1470                            "strategies": "Use .loc[] or .iloc[] for DataFrame indexing instead of chained indexing."},
1471                           {"name": "Datatype Checker",
1472                            "how_to_fix": "Set data types explicitly when importing data.",
1473                            "benefits": "Ensures correct data format and reduces memory usage.",
1474                            "strategies": "Use dtype parameter in Pandas read functions (e.g., pd.read_csv)."},
1475                           {"name": "Column Selection Checker",
1476                            "how_to_fix": "Select necessary columns after importing DataFrame.",
1477                            "benefits": "Clarifies data usage and improves performance.",
1478                            "strategies": "Use column selection methods (e.g., df[['col1', 'col2']])."},
1479                           {"name": "Merge Parameter Checker",
1480                            "how_to_fix": "Specify how, on, and validate parameters in merge operations.",
1481                            "benefits": "Ensures accurate data merging and prevents data loss.",
1482                            "strategies": "Use appropriate parameters in pd.merge function."},
1483                           {"name": "InPlace Checker",
1484                            "how_to_fix": "Assign operations to a new DataFrame variable.",
1485                            "benefits": "Prevents data loss and improves code clarity.",
1486                            "strategies": "Assign operation results to a new variable instead of using inplace=True."},
1487                           {"name": "DataFrame Conversion Checker",
1488                            "how_to_fix": "Use .to_numpy() instead of .values.",
1489                            "benefits": "Ensures future compatibility and avoids unexpected behaviour.",
1490                            "strategies": "Replace .values with .to_numpy() in Pandas DataFrame conversion."}],
1491                "NumPy": [{"name": "NaN Equality Checker",
1492                           "how_to_fix": "Use np.isnan() to check for NaN values.",
1493                           "benefits": "Ensures accurate data handling and avoids logical errors.",
1494                           "strategies": "Replace == np.nan with np.isnan()."},
1495                          {"name": "Randomness Control Checker",
1496                           "how_to_fix": "Use np.random.seed() for reproducibility.",
1497                           "benefits": "Enables reproducible results and debugging.",
1498                           "strategies": "Set random seed using np.random.seed(seed_value)."},
1499                          {"name": "Array Creation Efficiency",
1500                           "how_to_fix": "Specify dtype when creating arrays.",
1501                           "benefits": "Improves memory efficiency and avoids unnecessary conversions.",
1502                           "strategies": "Use dtype parameter in np.array, np.zeros, np.ones, np.empty functions."},
1503                          {"name": "Inefficient Operations",
1504                           "how_to_fix": "Optimize numerical operations.",
1505                           "benefits": "Improves performance and reduces execution time.",
1506                           "strategies": "Use vectorized functions instead of loops for element-wise operations."},
1507                          {"name": "Dtype Consistency",
1508                           "how_to_fix": "Ensure consistent data types in operations.",
1509                           "benefits": "Improves accuracy and avoids logical errors.",
1510                           "strategies": "Use consistent data types in np.sum, np.mean, np.max, np.min functions."},
1511                          {"name": "Broadcasting Risk",
1512                           "how_to_fix": "Specify axis in operations between arrays.",
1513                           "benefits": "Improves performance and avoids broadcasting issues.",
1514                           "strategies": "Use axis parameter in np.reshape, np.transpose functions."},
1515                          {"name": "Copy-View Confusion",
1516                           "how_to_fix": "Explicitly copy arrays before slicing.",
1517                           "benefits": "Improves performance and avoids unexpected behaviour.",
1518                           "strategies": "Use .copy() method or np.copy function before slicing."},
1519                          {"name": "Missing Axis Specification",
1520                           "how_to_fix": "Specify axis in array operations.",
1521                           "benefits": "Improves performance and avoids unexpected behaviour.",
1522                           "strategies": "Use axis parameter in np.sum, np.mean, np.max, np.min, np.argmax, np.argmin, np.any, np.all functions."}],
1523                "ScikitLearn": [{"name": "Scaler Missing Checker",
1524                                 "how_to_fix": "Apply scaling before scaling-sensitive operations.",
1525                                 "benefits": "Improves model performance and accuracy.",
1526                                 "strategies": "Use StandardScaler, MinMaxScaler, etc., before applying PCA, SVM, etc."},
1527                                {"name": "Pipeline Checker",
1528                                 "how_to_fix": "Use Pipelines for all scikit-learn estimators.",
1529                                 "benefits": "Prevents data leakage and ensures correct model evaluation.",
1530                                 "strategies": "Implement Pipeline from sklearn.pipeline for preprocessing and model fitting."},
1531                                {"name": "Cross Validation Checker",
1532                                 "how_to_fix": "Implement cross-validation techniques for robust model evaluation.",
1533                                 "benefits": "Enhances model performance and reduces overfitting.",
1534                                 "strategies": "Use cross_val_score, KFold, or other cross-validation methods."},
1535                                {"name": "Randomness Control Checker",
1536                                 "how_to_fix": "Set random_state in estimators for reproducibility.",
1537                                 "benefits": "Ensures reproducible results and consistent model behaviour.",
1538                                 "strategies": "Set random_state to a fixed value in scikit-learn estimators."},
1539                                {"name": "Verbose Mode Checker",
1540                                 "how_to_fix": "Enable verbose mode for long training processes.",
1541                                 "benefits": "Provides better insights into model training progress.",
1542                                 "strategies": "Set verbose=True in scikit-learn estimators with long training times."},
1543                                {"name": "Dependent Threshold Checker",
1544                                 "how_to_fix": "Use threshold-independent metrics alongside threshold-dependent ones.",
1545                                 "benefits": "Provides a comprehensive evaluation of model performance.",
1546                                 "strategies": "Include metrics like ROC AUC score alongside accuracy or F1-score."},
1547                                {"name": "Unit Testing Checker",
1548                                 "how_to_fix": "Write unit tests for data processing and model components.",
1549                                 "benefits": "Ensures code reliability and prevents bugs.",
1550                                 "strategies": "Use unittest or pytest to write and run tests for individual components."},
1551                                {"name": "Data Leakage Checker",
1552                                 "how_to_fix": "Split data before any preprocessing or feature engineering.",
1553                                 "benefits": "Prevents data leakage and ensures valid model evaluation.",
1554                                 "strategies": "Use train_test_split before any data preprocessing steps."},
1555                                {"name": "Exception Handling Checker",
1556                                 "how_to_fix": "Implement proper exception handling for model operations.",
1557                                 "benefits": "Improves code robustness and error reporting.",
1558                                 "strategies": "Use try-except blocks to handle potential exceptions in model training and prediction."}],
1559                "PyTorch": [{"name": "Randomness Control Checker",
1560                             "how_to_fix": "Set random seed using torch.manual_seed() for reproducibility.",
1561                             "benefits": "Ensures reproducible results across different runs.",
1562                             "strategies": "Add torch.manual_seed(seed_value) at the start of your script."},
1563                            {"name": "Deterministic Algorithm Usage Checker",
1564                             "how_to_fix": "Enable deterministic algorithms using torch.use_deterministic_algorithms(True).",
1565                             "benefits": "Ensures consistent results across different hardware and runs.",
1566                             "strategies": "Set torch.use_deterministic_algorithms(True) and handle any required adjustments."},
1567                            {"name": "Randomness Control Checker (PyTorch-Dataloader)",
1568                             "how_to_fix": "Set worker_init_fn and generator in DataLoader for reproducible data loading.",
1569                             "benefits": "Ensures consistent data loading across different runs.",
1570                             "strategies": "Configure worker_init_fn and generator parameters in DataLoader initialization."},
1571                            {"name": "Mask Missing Checker",
1572                             "how_to_fix": "Use appropriate masking when dealing with log operations.",
1573                             "benefits": "Prevents numerical errors and improves model stability.",
1574                             "strategies": "Apply masks before log operations to handle invalid values."},
1575                            {"name": "Net Forward Checker",
1576                             "how_to_fix": "Use model(input) instead of model.forward(input).",
1577                             "benefits": "Follows PyTorch best practices and handles hooks properly.",
1578                             "strategies": "Replace direct forward() calls with the recommended calling syntax."},
1579                            {"name": "Gradient Clear Checker",
1580                             "how_to_fix": "Clear gradients before each backward pass using optimizer.zero_grad().",
1581                             "benefits": "Prevents gradient accumulation and ensures correct updates.",
1582                             "strategies": "Add optimizer.zero_grad() before loss.backward() in training loop."},
1583                            {"name": "Batch Normalisation Checker",
1584                             "how_to_fix": "Include BatchNorm layers in your model architecture.",
1585                             "benefits": "Improves training stability and model convergence.",
1586                             "strategies": "Add torch.nn.BatchNorm layers after convolutional or linear layers."},
1587                            {"name": "Dropout Usage Checker",
1588                             "how_to_fix": "Include Dropout layers for regularization.",
1589                             "benefits": "Reduces overfitting and improves model generalization.",
1590                             "strategies": "Add torch.nn.Dropout layers in your model architecture."},
1591                            {"name": "Data Augmentation Checker",
1592                             "how_to_fix": "Implement data augmentation using torchvision.transforms.",
1593                             "benefits": "Improves model robustness and reduces overfitting.",
1594                             "strategies": "Use torchvision.transforms to apply various augmentation techniques."},
1595                            {"name": "Learning Rate Scheduler Checker",
1596                             "how_to_fix": "Implement learning rate scheduling.",
1597                             "benefits": "Improves training convergence and final model performance.",
1598                             "strategies": "Use torch.optim.lr_scheduler for dynamic learning rate adjustment."},
1599                            {"name": "Logging Checker",
1600                             "how_to_fix": "Implement proper logging using tensorboard or similar tools.",
1601                             "benefits": "Enables better experiment tracking and debugging.",
1602                             "strategies": "Use tensorboardX or torch.utils.tensorboard for logging."},
1603                            {"name": "Model Evaluation Checker",
1604                             "how_to_fix": "Set model to evaluation mode using model.eval().",
1605                             "benefits": "Ensures correct behavior of layers like BatchNorm and Dropout during inference.",
1606                             "strategies": "Call model.eval() before validation/testing and model.train() before training."}],
1607                "TensorFlow": [{"name": "Randomness Control Checker",
1608                                "how_to_fix": "Set random seed using tf.random.set_seed().",
1609                                "benefits": "Ensures reproducible results across different runs.",
1610                                "strategies": "Add tf.random.set_seed(seed_value) at the start of your script."},
1611                               {"name": "Early Stopping Checker",
1612                                "how_to_fix": "Implement early stopping using tf.keras.callbacks.EarlyStopping.",
1613                                "benefits": "Prevents overfitting and reduces unnecessary training time.",
1614                                "strategies": "Add EarlyStopping callback to model.fit()."},
1615                               {"name": "Checkpointing Checker",
1616                                "how_to_fix": "Implement model checkpointing using tf.keras.callbacks.ModelCheckpoint.",
1617                                "benefits": "Enables model recovery and saves best performing models.",
1618                                "strategies": "Add ModelCheckpoint callback to model.fit()."},
1619                               {"name": "Memory Release Checker",
1620                                "how_to_fix": "Clear session after model creation/training using tf.keras.backend.clear_session().",
1621                                "benefits": "Prevents memory leaks and reduces resource usage.",
1622                                "strategies": "Call clear_session() after completing major operations."},
1623                               {"name": "Mask Missing Checker",
1624                                "how_to_fix": "Use appropriate masking for log operations.",
1625                                "benefits": "Prevents numerical errors and improves model stability.",
1626                                "strategies": "Apply masks before log operations to handle invalid values."},
1627                               {"name": "Tensor Array Checker",
1628                                "how_to_fix": "Use tf.TensorArray for dynamic tensor operations.",
1629                                "benefits": "Improves memory efficiency for dynamic computations.",
1630                                "strategies": "Replace Python lists with tf.TensorArray for growing tensors."},
1631                               {"name": "Dependent Threshold Checker",
1632                                "how_to_fix": "Use threshold-independent metrics like AUC.",
1633                                "benefits": "Provides more robust model evaluation.",
1634                                "strategies": "Include tf.keras.metrics.AUC in your model metrics."},
1635                               {"name": "Logging Checker",
1636                                "how_to_fix": "Implement TensorBoard logging.",
1637                                "benefits": "Enables better experiment tracking and visualization.",
1638                                "strategies": "Use tf.summary or TensorBoard callbacks for logging."},
1639                               {"name": "Batch Normalisation Checker",
1640                                "how_to_fix": "Include BatchNormalization layers in your model.",
1641                                "benefits": "Improves training stability and model convergence.",
1642                                "strategies": "Add tf.keras.layers.BatchNormalization layers to your model."},
1643                               {"name": "Dropout Usage Checker",
1644                                "how_to_fix": "Include Dropout layers for regularization.",
1645                                "benefits": "Reduces overfitting and improves generalization.",
1646                                "strategies": "Add tf.keras.layers.Dropout layers to your model."},
1647                               {"name": "Data Augmentation Checker",
1648                                "how_to_fix": "Implement data augmentation using ImageDataGenerator.",
1649                                "benefits": "Improves model robustness and reduces overfitting.",
1650                                "strategies": "Use tf.keras.preprocessing.image.ImageDataGenerator for augmentation."},
1651                               {"name": "Learning Rate Scheduler Checker",
1652                                "how_to_fix": "Implement learning rate scheduling.",
1653                                "benefits": "Improves training convergence and final model performance.",
1654                                "strategies": "Use LearningRateScheduler callback or custom scheduling."},
1655                               {"name": "Model Evaluation Checker",
1656                                "how_to_fix": "Evaluate model performance using model.evaluate().",
1657                                "benefits": "Provides standardized model evaluation.",
1658                                "strategies": "Use model.evaluate() on validation/test data."},
1659                               {"name": "Unit Testing Checker",
1660                                "how_to_fix": "Implement unit tests using tf.test.TestCase.",
1661                                "benefits": "Ensures code reliability and prevents regressions.",
1662                                "strategies": "Create test classes inheriting from tf.test.TestCase."},
1663                               {"name": "Exception Handling Checker",
1664                                "how_to_fix": "Implement proper exception handling.",
1665                                "benefits": "Improves code robustness and error reporting.",
1666                                "strategies": "Use try-except blocks to handle TensorFlow-specific exceptions."}]}