package mill.javalib.zinc

import mill.api.JsonFormatters.*
import mill.api.PathRef
import mill.api.daemon.internal.CompileProblemReporter
import mill.api.daemon.{Logger, Result}
import mill.client.lock.*
import mill.javalib.api.internal.*
import mill.javalib.api.{CompilationResult, JvmWorkerUtil, Versions}
import mill.javalib.api.internal.ZincCompilerBridgeProvider
import mill.javalib.api.internal.ZincCompilerBridgeProvider.AcquireResult
import mill.javalib.worker.*
import mill.javalib.zinc.ZincWorker.*
import mill.util.{CachedFactory, CachedFactoryWithInitData, RefCountedClassLoaderCache}
import sbt.internal.inc
import sbt.internal.inc.*
import sbt.internal.inc.classpath.ClasspathUtil
import sbt.internal.inc.consistent.ConsistentFileAnalysisStore
import sbt.internal.util.ConsoleOut
import sbt.mill.SbtLoggerUtils
import xsbti.compile.*
import xsbti.compile.analysis.ReadWriteMappers
import xsbti.{PathBasedFile, VirtualFile}

import java.io.File
import java.net.URLClassLoader
import java.util.Optional
import scala.collection.mutable

/**
 * Runs Zinc-based Java/mixed compilation and Scaladoc generation, caching compiler
 * instances (and the classloaders behind them) between invocations.
 *
 * @param jobs number of parallel jobs; also used as the maximum size of the compiler caches
 * @param useFileLocks use file-based locking instead of PID-based locking
 */
class ZincWorker(jobs: Int, useFileLocks: Boolean = false) extends AutoCloseable { self =>
  private val incrementalCompiler = new sbt.internal.inc.IncrementalCompilerImpl()

  /** In-process locks for compiler-bridge preparation, one per Scala version; accessed under `synchronized`. */
  private val compilerBridgeLocks: mutable.Map[String, MemoryLock] = mutable.Map.empty

  /**
   * Ref-counted cache of compiler classloaders keyed by classpath. Classes under the
   * `xsbti` prefix are presumably delegated to this worker's own classloader
   * (`sharedLoader`) so Zinc's interfaces line up across the boundary.
   */
  private val classloaderCache = new RefCountedClassLoaderCache(
    sharedLoader = getClass.getClassLoader,
    sharedPrefixes = Seq("xsbti")
  ) {

    /**
     * On release of a compiler classloader, cancel the `java.util.Timer` held in scalac's
     * `FileBasedCache` companion (presumably so its thread does not outlive the loader).
     * Every lookup is reflective and optional: if any class, field or method is missing,
     * or the timer option is empty, nothing is done.
     */
    override def extraRelease(cl: ClassLoader): Unit = {
      // Look up a public zero-argument method by name on the runtime class of `target`
      def nullaryMethod(target: AnyRef, name: String): Option[java.lang.reflect.Method] =
        target.getClass.getMethods.find(m => m.getName == name && m.getParameterCount == 0)

      for {
        cls <- {
          try Some(cl.loadClass("scala.tools.nsc.classpath.FileBasedCache$"))
          catch {
            case _: ClassNotFoundException => None
          }
        }
        moduleField <- {
          try Some(cls.getField("MODULE$"))
          catch {
            case _: NoSuchFieldException => None
          }
        }
        module = moduleField.get(null)
        timerField <- {
          try Some(cls.getDeclaredField("scala$tools$nsc$classpath$FileBasedCache$$timer"))
          catch {
            case _: NoSuchFieldException => None
          }
        }
        _ = timerField.setAccessible(true)
        // This is a `scala.Option` loaded by the compiler's classloader, not ours, so it can
        // only be inspected reflectively. Check `isEmpty` before calling `get`: the previous
        // `getOrElse(null)` evaluated the null default on `None` and threw.
        timerOpt0 <- Option(timerField.get(module))
        isEmptyMethod <- nullaryMethod(timerOpt0, "isEmpty")
        if !isEmptyMethod.invoke(timerOpt0).asInstanceOf[Boolean]
        getMethod <- nullaryMethod(timerOpt0, "get")
        timer <- Option(getMethod.invoke(timerOpt0)).collect { case t: java.util.Timer => t }
      } {
        timer.cancel()
      }
    }
  }

  /** Cache of Scala compiler instances, sized to `jobs`; the bridge provider is passed as init data. */
  private val scalaCompilerCache =
    new CachedFactoryWithInitData[
      ScalaCompilerCacheKey,
      ZincCompilerBridgeProvider,
      ScalaCompilerCached
    ] {
      override def maxCacheSize: Int = jobs

      /**
       * Builds the `Compilers` for `key`: obtains the compiler bridge, acquires a classloader
       * reference for the combined compiler + plugin classpath and wraps it in a `ScalaInstance`.
       * If anything fails after the classloader reference is taken, the reference is released
       * before the exception propagates, because `teardown` is never called for a failed setup.
       */
      override def setup(
          key: ScalaCompilerCacheKey,
          compilerBridgeProvider: ZincCompilerBridgeProvider
      ): ScalaCompilerCached = {
        import key.*

        val combinedCompilerJars = combinedCompilerClasspath.iterator.map(_.path.toIO).toArray

        // Use the explicitly configured bridge if there is one, otherwise download/compile it
        val compiledCompilerBridge = compilerBridgeOpt.map(_.path).getOrElse {
          compileBridgeIfNeeded(
            scalaVersion,
            scalaOrganization,
            compilerClasspath.map(_.path),
            compilerBridgeProvider
          )
        }

        // Resolve the library jar before acquiring a classloader reference, so that a missing
        // jar fails early without touching the refcount.
        val libraryJar = libraryJarNameGrep(
          compilerClasspath,
          // if Dotty or Scala 3.0 - 3.7, use the 2.13 version of the standard library
          if (JvmWorkerUtil.enforceScala213Library(scalaVersion)) "2.13."
          // otherwise use the library matching the Scala version
          else scalaVersion
        ).path.toIO

        val classLoader = classloaderCache.get(combinedCompilerClasspath)
        try {
          val scalaInstance = new inc.ScalaInstance(
            version = scalaVersion,
            loader = classLoader,
            loaderCompilerOnly = classLoader,
            loaderLibraryOnly = ClasspathUtil.rootLoader,
            libraryJars = Array(libraryJar),
            compilerJars = combinedCompilerJars,
            allJars = combinedCompilerJars,
            explicitActual = None
          )
          val compilers = incrementalCompiler.compilers(
            javaTools = getLocalOrCreateJavaTools(),
            scalac = ZincUtil.scalaCompiler(scalaInstance, compiledCompilerBridge.toIO)
          )
          ScalaCompilerCached(classLoader, compilers)
        } catch {
          // Cleanup-and-rethrow: release our reference for any failure, then propagate unchanged
          case e: Throwable =>
            classloaderCache.release(combinedCompilerClasspath)
            throw e
        }
      }

      /** Releases the classloader reference taken in `setup`. */
      override def teardown(key: ScalaCompilerCacheKey, value: ScalaCompilerCached): Unit = {
        classloaderCache.release(key.combinedCompilerClasspath)
      }
    }

  /** Cache of Java-only `Compilers`, keyed by javac options and sized to `jobs`. */
  private val javaOnlyCompilerCache = new CachedFactory[JavaCompilerCacheKey, Compilers] {

    override def setup(key: JavaCompilerCacheKey): Compilers = {
      // Only options relevant for the compiler runtime influence the cached instance
      // Keep the classpath as written by the user
      val classpathOptions = ClasspathOptions.of(
        /*bootLibrary*/ false,
        /*compiler*/ false,
        /*extra*/ false,
        /*autoBoot*/ false,
        /*filterLibrary*/ false
      )

      val dummyFile = new java.io.File("")
      // Zinc does not have an entry point for Java-only compilation, so we need
      // to make up a dummy ScalaCompiler instance.
      val scalac = ZincUtil.scalaCompiler(
        new inc.ScalaInstance(
          version = "",
          loader = null,
          loaderCompilerOnly = null,
          loaderLibraryOnly = null,
          libraryJars = Array(dummyFile),
          compilerJars = Array(dummyFile),
          allJars = new Array(0),
          explicitActual = Some("")
        ),
        dummyFile,
        classpathOptions // this is used for javac too
      )

      val javaTools = getLocalOrCreateJavaTools()
      val compilers = incrementalCompiler.compilers(javaTools, scalac)
      compilers
    }

    // Nothing to release: no classloader is acquired for Java-only compilers
    override def teardown(key: JavaCompilerCacheKey, value: Compilers): Unit = ()

    override def maxCacheSize: Int = jobs
  }

  /**
   * Compiles Java sources with Zinc, using a cached Java-only compiler for `op.javacOptions`.
   *
   * @return the compilation result, or a failure if compilation failed
   */
  def compileJava(
      op: ZincOp.CompileJava,
      reporter: Option[CompileProblemReporter],
      reportCachedProblems: Boolean,
      localConfig: ZincWorker.LocalConfig,
      processConfig: ZincWorker.ProcessConfig
  ): Result[CompilationResult] = {
    val cacheKey = JavaCompilerCacheKey(op.javacOptions)
    javaOnlyCompilerCache.withValue(cacheKey) { compilers =>
      compileInternal(
        upstreamCompileOutput = op.upstreamCompileOutput,
        sources = op.sources,
        compileClasspath = op.compileClasspath,
        javacOptions = op.javacOptions,
        scalacOptions = Nil,
        compilers = compilers,
        reporter = reporter,
        reportCachedProblems = reportCachedProblems,
        incrementalCompilation = op.incrementalCompilation,
        auxiliaryClassFileExtensions = Seq.empty,
        localConfig = localConfig,
        processConfig = processConfig,
        workDir = op.workDir
      )
    }
  }

  /**
   * Compiles mixed Scala/Java sources with Zinc, using a cached Scala compiler for the
   * requested Scala version, compiler classpath and plugins.
   *
   * @return the compilation result, or a failure if compilation failed
   */
  def compileMixed(
      op: ZincOp.CompileMixed,
      reporter: Option[CompileProblemReporter],
      reportCachedProblems: Boolean,
      localConfig: ZincWorker.LocalConfig,
      processConfig: ZincWorker.ProcessConfig
  ): Result[CompilationResult] = {
    withScalaCompilers(
      scalaVersion = op.scalaVersion,
      scalaOrganization = op.scalaOrganization,
      compilerClasspath = op.compilerClasspath,
      scalacPluginClasspath = op.scalacPluginClasspath,
      compilerBridgeOpt = op.compilerBridgeOpt,
      javacOptions = op.javacOptions,
      compilerBridgeProvider = processConfig.compilerBridgeProvider
    ) { compilers =>
      compileInternal(
        upstreamCompileOutput = op.upstreamCompileOutput,
        sources = op.sources,
        compileClasspath = op.compileClasspath,
        javacOptions = op.javacOptions,
        scalacOptions = op.scalacOptions,
        compilers = compilers,
        reporter = reporter,
        reportCachedProblems = reportCachedProblems,
        incrementalCompilation = op.incrementalCompilation,
        auxiliaryClassFileExtensions = op.auxiliaryClassFileExtensions,
        localConfig = localConfig,
        processConfig = processConfig,
        workDir = op.workDir
      )
    }
  }

  /**
   * Runs the Scaladoc tool matching `op.scalaVersion`, loaded reflectively from the cached
   * compiler's classloader, with `op.args`.
   *
   * - Dotty 0.x and Scala 3 milestones use `dotty.tools.dottydoc.DocDriver`.
   * - Scala 3 uses `dotty.tools.scaladoc.Main`.
   * - Scala 2 uses `scala.tools.nsc.ScalaDoc`.
   *
   * The whole run is wrapped in `mill.util.Retry` with `count = 2`.
   *
   * @return `true` if the tool reported success (no errors)
   */
  def scaladocJar(
      op: ZincOp.ScaladocJar,
      processConfig: ZincWorker.ProcessConfig
  ): Boolean = {
    withScalaCompilers(
      scalaVersion = op.scalaVersion,
      scalaOrganization = op.scalaOrganization,
      compilerClasspath = op.compilerClasspath,
      scalacPluginClasspath = op.scalacPluginClasspath,
      compilerBridgeOpt = op.compilerBridgeOpt,
      javacOptions = Nil,
      compilerBridgeProvider = processConfig.compilerBridgeProvider
    ) { compilers =>
      // Not sure why dotty scaladoc is flaky, but add retries to workaround it
      // https://github.com/com-lihaoyi/mill/issues/4556
      mill.util.Retry(
        count = 2,
        failWithFirstError = true,
        logger = msg => processConfig.log.debug(msg)
      ) {
        if (
          JvmWorkerUtil.isDotty(op.scalaVersion) || JvmWorkerUtil.isScala3Milestone(op.scalaVersion)
        ) {
          // dotty 0.x and scala 3 milestones use the dotty-doc tool
          val dottydocClass =
            compilers.scalac().scalaInstance().loader().loadClass(
              "dotty.tools.dottydoc.DocDriver"
            )
          val dottydocMethod = dottydocClass.getMethod("process", classOf[Array[String]])
          val reporter =
            dottydocMethod.invoke(dottydocClass.getConstructor().newInstance(), op.args.toArray)
          val hasErrorsMethod = reporter.getClass.getMethod("hasErrors")
          !hasErrorsMethod.invoke(reporter).asInstanceOf[Boolean]
        } else if (JvmWorkerUtil.isScala3(op.scalaVersion)) {
          // DottyDoc makes use of `com.fasterxml.jackson.databind.Module` which
          // requires the ContextClassLoader to be set appropriately
          mill.api.daemon.ClassLoader.withContextClassLoader(this.getClass.getClassLoader) {
            val scaladocClass =
              compilers.scalac().scalaInstance().loader().loadClass("dotty.tools.scaladoc.Main")

            val scaladocMethod = scaladocClass.getMethod("run", classOf[Array[String]])
            val reporter =
              scaladocMethod.invoke(scaladocClass.getConstructor().newInstance(), op.args.toArray)
            val hasErrorsMethod = reporter.getClass.getMethod("hasErrors")
            !hasErrorsMethod.invoke(reporter).asInstanceOf[Boolean]
          }
        } else {
          val scaladocClass =
            compilers.scalac().scalaInstance().loader().loadClass("scala.tools.nsc.ScalaDoc")
          val scaladocMethod = scaladocClass.getMethod("process", classOf[Array[String]])
          scaladocMethod.invoke(
            scaladocClass.getConstructor().newInstance(),
            op.args.toArray
          ).asInstanceOf[Boolean]
        }
      }
    }
  }

  /**
   * Closes all caches. Each close is attempted even if an earlier one throws. The
   * compiler caches are closed first because their teardown (presumably run on close)
   * releases classloaders back into `classloaderCache`.
   */
  def close(): Unit = {
    try scalaCompilerCache.close()
    finally {
      try javaOnlyCompilerCache.close()
      finally classloaderCache.close()
    }
  }

  /** Borrows a cached Scala `Compilers` instance matching the given settings and applies `f` to it. */
  private def withScalaCompilers[T](
      scalaVersion: String,
      scalaOrganization: String,
      compilerClasspath: Seq[PathRef],
      scalacPluginClasspath: Seq[PathRef],
      compilerBridgeOpt: Option[PathRef],
      javacOptions: Seq[String],
      compilerBridgeProvider: ZincCompilerBridgeProvider
  )(f: Compilers => T) = {
    val cacheKey = ScalaCompilerCacheKey(
      scalaVersion,
      compilerClasspath,
      scalacPluginClasspath,
      compilerBridgeOpt,
      scalaOrganization,
      javacOptions
    )
    scalaCompilerCache.withValue(cacheKey, compilerBridgeProvider) { cached =>
      f(cached.compilers)
    }
  }

  /**
   * Shared Zinc compilation driver for Java-only and mixed compilation.
   *
   * Classes are written to `workDir / "classes"`. The Zinc analysis is read from and
   * written to `workDir / zincCache`. Analyses of upstream modules are looked up from
   * `upstreamCompileOutput` by their classes directory.
   *
   * The `scala.color` system property is set for the duration of the compile and
   * restored afterwards. Because it is a JVM-global property, concurrent compiles may
   * observe each other's value.
   *
   * @param incrementalCompilation whether to feed the previous analysis to Zinc
   * @param reportCachedProblems whether to re-report problems recorded in the analysis
   * @return `Result.Success` with the analysis file and classes dir, or `Result.Failure` on `CompileFailed`
   */
  private def compileInternal(
      upstreamCompileOutput: Seq[CompilationResult],
      sources: Seq[os.Path],
      compileClasspath: Seq[os.Path],
      javacOptions: Seq[String],
      scalacOptions: Seq[String],
      compilers: Compilers,
      reporter: Option[CompileProblemReporter],
      reportCachedProblems: Boolean,
      incrementalCompilation: Boolean,
      auxiliaryClassFileExtensions: Seq[String],
      zincCache: os.SubPath = os.sub / "zinc",
      localConfig: ZincWorker.LocalConfig,
      processConfig: ZincWorker.ProcessConfig,
      workDir: os.Path
  ): Result[CompilationResult] = {

    os.makeDir.all(workDir)

    val classesDir = workDir / "classes"
    val analysisFile = workDir / zincCache

    if (localConfig.logDebugEnabled) {
      processConfig.log.debug(
        s"""Compiling:
           |  javacOptions: ${javacOptions.map("'" + _ + "'").mkString(" ")}
           |  scalacOptions: ${scalacOptions.map("'" + _ + "'").mkString(" ")}
           |  sources: ${sources.map("'" + _ + "'").mkString(" ")}
           |  classpath: ${compileClasspath.map("'" + _ + "'").mkString(" ")}
           |  output: $classesDir"""
          .stripMargin
      )
    }

    reporter.foreach(_.start())

    val consoleAppender = SbtLoggerUtils.ConciseLevelConsoleAppender(
      name = "ZincLogAppender",
      log = s => processConfig.log.info(s),
      ansiCodesSupported0 = localConfig.logPromptColored
    )
    val loggerId = Thread.currentThread().getId.toString
    val zincLogLevel =
      if (localConfig.logDebugEnabled) sbt.util.Level.Debug else sbt.util.Level.Info
    val logger = SbtLoggerUtils.createLogger(loggerId, consoleAppender, zincLogLevel)

    val maxErrors = reporter.map(_.maxErrors).getOrElse(CompileProblemReporter.defaultMaxErrors)

    // Upstream classes directory -> that module's analysis file
    val analysisMap0 = upstreamCompileOutput.map(c => c.classes.path -> c.analysisFile).toMap

    // Zinc lookup: analysis for a classpath entry, if it is an upstream classes directory
    def analysisMap(f: VirtualFile): Optional[CompileAnalysis] = {
      val analysisFile = f match {
        case pathBased: PathBasedFile => analysisMap0.get(os.Path(pathBased.toPath))
        case _ => None
      }
      analysisFile match {
        case Some(zincPath) => fileAnalysisStore(zincPath).get().map(_.getAnalysis)
        case None => Optional.empty[CompileAnalysis]
      }
    }

    val lookup = MockedLookup(analysisMap)

    val store = fileAnalysisStore(analysisFile)

    // Fix jdk classes marked as binary dependencies, see https://github.com/com-lihaoyi/mill/pull/1904
    val converter = MappedFileConverter.empty
    val classpath = (compileClasspath.iterator ++ Some(classesDir))
      .map(path => converter.toVirtualFile(path.toNIO))
      .toArray
    val virtualSources = sources.iterator
      .map(path => converter.toVirtualFile(path.toNIO))
      .toArray

    val incOptions = IncOptions.of().withAuxiliaryClassFiles(
      auxiliaryClassFileExtensions.map(new AuxiliaryClassFileExtension(_)).toArray
    )
    // Forward Zinc progress to the reporter; returning true means "do not cancel"
    val compileProgress = reporter.map { reporter =>
      new CompileProgress {
        override def advance(
            current: Int,
            total: Int,
            prevPhase: String,
            nextPhase: String
        ): Boolean = {
          reporter.notifyProgress(progress = current, total = total)
          true
        }
      }
    }

    // Scala 3 emits colors unless told otherwise, so disable them when the prompt is not colored
    val addColorNeverOption = Option.when(
      !localConfig.logPromptColored &&
        compilers.scalac().scalaInstance().version().startsWith("3.") &&
        // might be too broad
        !scalacOptions.exists(_.startsWith("-color:"))
    ) {
      "-color:never"
    }

    val finalScalacOptions = addColorNeverOption.toSeq ++ scalacOptions

    val (originalSourcesMap, posMapperOpt) = PositionMapper.create(virtualSources)

    val newReporter = reporter match {
      case None =>
        new ManagedLoggedReporter(maxErrors, logger) with RecordingReporter
          with TransformingReporter(
            localConfig.logPromptColored,
            posMapperOpt.orNull,
            localConfig.workspaceRoot
          ) {}
      case Some(forwarder) =>
        new ManagedLoggedReporter(maxErrors, logger)
          with ForwardingReporter(forwarder)
          with RecordingReporter
          with TransformingReporter(
            localConfig.logPromptColored,
            posMapperOpt.orNull,
            localConfig.workspaceRoot
          ) {}
    }

    val inputs = incrementalCompiler.inputs(
      classpath = classpath,
      sources = virtualSources,
      classesDirectory = classesDir.toNIO,
      earlyJarPath = None,
      scalacOptions = finalScalacOptions.toArray,
      javacOptions = javacOptions.toArray,
      maxErrors = maxErrors,
      sourcePositionMappers = Array(),
      order = CompileOrder.Mixed,
      compilers = compilers,
      setup = incrementalCompiler.setup(
        lookup = lookup,
        skip = false,
        // Point at the same absolute file the analysis store uses rather than a
        // CWD-relative sub path
        cacheFile = analysisFile.toNIO,
        cache = new FreshCompilerCache,
        incOptions = incOptions,
        reporter = newReporter,
        progress = compileProgress,
        earlyAnalysisStore = None,
        extra = Array()
      ),
      pr = if (incrementalCompilation) {
        val prev = store.get()
        PreviousResult.of(prev.map(_.getAnalysis), prev.map(_.getMiniSetup))
      } else {
        PreviousResult.of(Optional.empty[CompileAnalysis], Optional.empty[MiniSetup])
      },
      temporaryClassesDirectory = java.util.Optional.empty(),
      converter = converter,
      stampReader = Stamps.timeWrapBinaryStamps(converter)
    )

    val scalaColorProp = "scala.color"
    val previousScalaColor = sys.props(scalaColorProp)
    try {
      sys.props(scalaColorProp) = if (localConfig.logPromptColored) "true" else "false"

      val newResult = incrementalCompiler.compile(in = inputs, logger = logger)

      if (reportCachedProblems) newReporter.logOldProblems(newResult.analysis())

      store.set(AnalysisContents.create(newResult.analysis(), newResult.setup()))

      Result.Success(CompilationResult(analysisFile, PathRef(classesDir)))
    } catch {
      case e: CompileFailed =>
        Result.Failure(e.toString)
    } finally {
      // Tell the reporter about every source (and its original, if position-mapped), then finish
      for (rep <- reporter) {
        for (f <- sources) {
          rep.fileVisited(f.toNIO)
          for (f0 <- originalSourcesMap.get(f)) rep.fileVisited(f0.toNIO)
        }
        rep.finish()
      }
      previousScalaColor match {
        case null => sys.props.remove(scalaColorProp)
        case _ => sys.props(scalaColorProp) = previousScalaColor
      }
    }
  }

  /**
   * If needed, compile (for Scala 2) or download (for Dotty) the compiler bridge.
   *
   * The work is guarded by an in-process lock plus a cross-process directory lock, both
   * keyed by Scala version. A compiled bridge is marked complete with a `DONE` file so
   * later calls reuse it.
   *
   * @return a path to the directory containing the compiled classes, or to the downloaded jar file
   */
  private def compileBridgeIfNeeded(
      scalaVersion: String,
      scalaOrganization: String,
      compilerClasspath: Seq[os.Path],
      compilerBridgeProvider: ZincCompilerBridgeProvider
  ): os.Path = {
    val workingDir = compilerBridgeProvider.workspace / s"zinc-${Versions.zinc}" / scalaVersion

    os.makeDir.all(compilerBridgeProvider.workspace / "compiler-bridge-locks")
    val memoryLock = synchronized(
      compilerBridgeLocks.getOrElseUpdate(scalaVersion, new MemoryLock)
    )
    val compiledDest = workingDir / "compiled"
    val doneFile = compiledDest / "DONE"
    // Use a double-lock here because we need mutex both between threads within this
    // process, as well as between different processes since sometimes we are initializing
    // the compiler bridge inside a separate `ZincWorkerMain` subprocess
    val doubleLock = new DoubleLock(
      memoryLock,
      Lock.forDirectory(
        (compilerBridgeProvider.workspace / "compiler-bridge-locks" / scalaVersion).toString,
        useFileLocks
      )
    )
    try {
      doubleLock.lock()
      if (os.exists(doneFile)) compiledDest
      else {
        val acquired =
          compilerBridgeProvider.acquire(
            scalaVersion = scalaVersion,
            scalaOrganization = scalaOrganization
          )

        acquired match {
          case AcquireResult.Compiled(bridgeJar) => bridgeJar
          case AcquireResult.NotCompiled(bridgeClasspath, bridgeSourcesJar) =>
            ZincCompilerBridgeProvider.compile(
              compilerBridgeProvider.logInfo,
              workingDir,
              compiledDest,
              scalaVersion,
              compilerClasspath,
              bridgeClasspath,
              bridgeSourcesJar
            )
            os.write(doneFile, "")
            compiledDest
        }
      }
    } finally doubleLock.close()
  }

  /**
   * Dispatches a [[ZincOp]] to its handler. Test-discovery ops are delegated to
   * `mill.javalib.testrunner`. Each result is cast to the op's declared `Response` type.
   */
  def apply(
      op: ZincOp,
      reporter: Option[CompileProblemReporter],
      reportCachedProblems: Boolean,
      localConfig: ZincWorker.LocalConfig,
      processConfig: ZincWorker.ProcessConfig
  ): op.Response = {
    op match {
      case msg: ZincOp.CompileJava =>
        compileJava(
          msg,
          reporter,
          reportCachedProblems,
          localConfig,
          processConfig
        ).asInstanceOf[op.Response]

      case msg: ZincOp.CompileMixed =>
        compileMixed(
          msg,
          reporter,
          reportCachedProblems,
          localConfig,
          processConfig
        ).asInstanceOf[op.Response]

      case msg: ZincOp.ScaladocJar =>
        scaladocJar(msg, processConfig).asInstanceOf[op.Response]

      case msg: ZincOp.DiscoverTests =>
        mill.javalib.testrunner.DiscoverTests(msg).asInstanceOf[op.Response]

      case msg: ZincOp.GetTestTasks =>
        mill.javalib.testrunner.GetTestTasks(msg).asInstanceOf[op.Response]

      case msg: ZincOp.DiscoverJunit5Tests =>
        mill.javalib.testrunner.DiscoverJunit5Tests(msg).asInstanceOf[op.Response]
    }
  }
}

object ZincWorker {

  /**
   * Dependencies of the invocation.
   *
   * Can come either from the local [[ZincWorker]] running in [[JvmWorkerImpl]] or from a zinc worker running
   * in a different process.
   *
   * @param log logging actions used for debug/info output
   * @param consoleOut console output handle for this process
   * @param compilerBridgeProvider source of the Zinc compiler bridge (downloaded or compiled)
   */
  case class ProcessConfig(
      log: Logger.Actions,
      consoleOut: ConsoleOut,
      compilerBridgeProvider: ZincCompilerBridgeProvider
  )

  /**
   * The invocation context, always comes from the Mill's process.
   *
   * @param logDebugEnabled whether debug-level output is requested
   * @param logPromptColored whether output may contain ANSI colors
   * @param workspaceRoot root of the workspace, handed to the problem reporter
   */
  case class LocalConfig(
      dest: os.Path,
      logDebugEnabled: Boolean,
      logPromptColored: Boolean,
      workspaceRoot: os.Path
  ) derives upickle.ReadWriter

  /** Identifies a cached Scala compiler instance. */
  private case class ScalaCompilerCacheKey(
      scalaVersion: String,
      compilerClasspath: Seq[PathRef],
      scalacPluginClasspath: Seq[PathRef],
      compilerBridgeOpt: Option[PathRef],
      scalaOrganization: String,
      javacOptions: Seq[String]
  ) {

    /** Compiler jars followed by scalac plugin jars. */
    val combinedCompilerClasspath: Seq[PathRef] = compilerClasspath.concat(scalacPluginClasspath)
  }

  /** A cached Scala compiler together with the classloader it was loaded from. */
  private case class ScalaCompilerCached(classLoader: URLClassLoader, compilers: Compilers)

  /** Identifies a cached Java-only compiler instance. */
  private case class JavaCompilerCacheKey(javacOptions: Seq[String])

  /** Uses the in-process javac/javadoc when available, falling back to forked tools otherwise. */
  private def getLocalOrCreateJavaTools(): JavaTools =
    javac.JavaTools(
      javac.JavaCompiler.local.getOrElse(javac.JavaCompiler.fork()),
      javac.Javadoc.local.getOrElse(javac.Javadoc.fork())
    )

  /** Finds the non-sources `scala-library` jar for `scalaVersion` on the compiler classpath. */
  private def libraryJarNameGrep(compilerClasspath: Seq[PathRef], scalaVersion: String): PathRef = {
    val artifact = "scala-library"
    JvmWorkerUtil.grepJar(compilerClasspath, artifact, scalaVersion, sources = false)
  }

  /** Reproducible binary Zinc analysis store backed by the file at `path`. */
  private def fileAnalysisStore(path: os.Path): AnalysisStore = {
    // No need to utilize more than 8 cores to serialize a small file
    val serializerThreads = Runtime.getRuntime.availableProcessors().min(8)
    ConsistentFileAnalysisStore.binary(
      file = path.toIO,
      mappers = ReadWriteMappers.getEmptyMappers,
      reproducible = true,
      parallelism = serializerThreads
    )
  }
}
