diff --git a/examples/ios/FLiOS.xcodeproj/project.pbxproj b/examples/ios/FLiOS.xcodeproj/project.pbxproj index d8e4d6da207..1dc77b89bc6 100644 --- a/examples/ios/FLiOS.xcodeproj/project.pbxproj +++ b/examples/ios/FLiOS.xcodeproj/project.pbxproj @@ -11,7 +11,8 @@ 19764EED297D3ABE009F3E5D /* BenchmarkModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19764EEC297D3ABE009F3E5D /* BenchmarkModel.swift */; }; 19764EEF297D3AC9009F3E5D /* BenchmarkSuite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19764EEE297D3AC9009F3E5D /* BenchmarkSuite.swift */; }; 21075DC929788ED400311461 /* FLiOSModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21075DC829788ED400311461 /* FLiOSModel.swift */; }; - 216A90382A0CBE5400E0B532 /* flwr in Frameworks */ = {isa = PBXBuildFile; productRef = 216A90372A0CBE5400E0B532 /* flwr */; }; + 210BBB812AA5D10C00B283AC /* MLFlwrClientModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */; }; + 210BBB822AA5D10C00B283AC /* MLFlwrClient.swift in Sources */ = {isa = PBXBuildFile; fileRef = 210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */; }; 21A77AC32865B22B0062EBD8 /* Array+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A6F2865B22A0062EBD8 /* Array+Extensions.swift */; }; 21A77AC42865B22B0062EBD8 /* Math.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A702865B22A0062EBD8 /* Math.swift */; }; 21A77AC52865B22B0062EBD8 /* UIImage+CVPixelBuffer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A712865B22A0062EBD8 /* UIImage+CVPixelBuffer.swift */; }; @@ -127,6 +128,9 @@ 19764EEC297D3ABE009F3E5D /* BenchmarkModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BenchmarkModel.swift; sourceTree = ""; }; 19764EEE297D3AC9009F3E5D /* BenchmarkSuite.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BenchmarkSuite.swift; sourceTree = ""; }; 21075DC829788ED400311461 /* FLiOSModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FLiOSModel.swift; sourceTree = ""; }; + 210BBB7E2AA5D04500B283AC /* flwr */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = flwr; path = ../../src/swift/flwr; sourceTree = ""; }; + 210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLFlwrClient.swift; sourceTree = ""; }; + 210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLFlwrClientModel.swift; sourceTree = ""; }; 21A77A6F2865B22A0062EBD8 /* Array+Extensions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Array+Extensions.swift"; sourceTree = ""; }; 21A77A702865B22A0062EBD8 /* Math.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Math.swift; sourceTree = ""; }; 21A77A712865B22A0062EBD8 /* UIImage+CVPixelBuffer.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "UIImage+CVPixelBuffer.swift"; sourceTree = ""; }; @@ -229,7 +233,6 @@ buildActionMask = 2147483647; files = ( 50B5A2C6294C848000B0ABC6 /* flwr in Frameworks */, - 216A90382A0CBE5400E0B532 /* flwr in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -246,6 +249,14 @@ path = Benchmark; sourceTree = ""; }; + 210BBB7D2AA5D04500B283AC /* Packages */ = { + isa = PBXGroup; + children = ( + 210BBB7E2AA5D04500B283AC /* flwr */, + ); + name = Packages; + sourceTree = ""; + }; 21A77A6E2865B22A0062EBD8 /* CoreMLHelpers */ = { isa = PBXGroup; children = ( @@ -346,6 +357,8 @@ 21A77B162865B23D0062EBD8 /* CoreMLClient */ = { isa = PBXGroup; children = ( + 210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */, + 210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */, 50CDF0E0299D807300AEBD55 /* DataLoader.swift */, 21A77B202865B3B10062EBD8 /* MLModelInspect.swift */, 50EB97C929AC9C94000AC771 /* Constants.swift */, @@ -363,6 +376,7 @@ 21B540642865B14000185DEE = { isa = PBXGroup; children = ( + 210BBB7D2AA5D04500B283AC /* Packages */, 21B5406F2865B14000185DEE /* FLiOS */, 21B5406E2865B14000185DEE /* Products */, 21A77B2E2865B5420062EBD8 /* Frameworks */, @@ -440,7 +454,6 @@ name = FLiOS; packageProductDependencies = ( 50B5A2C5294C848000B0ABC6 /* flwr */, - 216A90372A0CBE5400E0B532 /* flwr */, ); productName = FlowerCoreML; productReference = 21B5406D2865B14000185DEE /* FLiOS.app */; @@ -471,7 +484,6 @@ ); mainGroup = 21B540642865B14000185DEE; packageReferences = ( - 216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */, ); productRefGroup = 21B5406E2865B14000185DEE /* Products */; projectDirPath = ""; @@ -522,6 +534,7 @@ 21A77ADC2865B22B0062EBD8 /* BayesianProbitRegressor.proto in Sources */, 21A77ADE2865B22B0062EBD8 /* Gazetteer.proto in Sources */, 21A77ADB2865B22B0062EBD8 /* SoundAnalysisPreprocessing.pb.swift in Sources */, + 210BBB822AA5D10C00B283AC /* MLFlwrClient.swift in Sources */, 21A77AC32865B22B0062EBD8 /* Array+Extensions.swift in Sources */, 21A77B152865B22B0062EBD8 /* Scaler.pb.swift in Sources */, 21A77B092865B22B0062EBD8 /* Parameters.pb.swift in Sources */, @@ -585,6 +598,7 @@ 21A77AF72865B22B0062EBD8 /* ItemSimilarityRecommender.pb.swift in Sources */, 21A77AD92865B22B0062EBD8 /* FeatureVectorizer.proto in Sources */, 21A77B012865B22B0062EBD8 /* TextClassifier.proto in Sources */, + 210BBB812AA5D10C00B283AC /* MLFlwrClientModel.swift in Sources */, 21A77B0F2865B22B0062EBD8 /* LinkedModel.proto in Sources */, 21A77ADF2865B22B0062EBD8 /* MIL.pb.swift in Sources */, 21A77AE52865B22B0062EBD8 /* GLMRegressor.pb.swift in Sources */, @@ -804,23 +818,7 @@ }; /* End XCConfigurationList section */ -/* Begin XCRemoteSwiftPackageReference section */ - 216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */ = { - isa = XCRemoteSwiftPackageReference; - repositoryURL = "https://github.com/adap/flower-swift.git"; - requirement = { - branch = main; - kind = branch; - }; - }; -/* End XCRemoteSwiftPackageReference section */ - /* Begin XCSwiftPackageProductDependency section */ - 216A90372A0CBE5400E0B532 /* flwr */ = { - isa = XCSwiftPackageProductDependency; - package = 216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */; - productName = flwr; - }; 21A77B312865B55D0062EBD8 /* Flower */ = { isa = XCSwiftPackageProductDependency; productName = Flower; diff --git a/src/swift/flwr/Sources/Flower/CoreML/MLFlwrClient.swift b/examples/ios/FLiOS/CoreMLClient/MLFlwrClient.swift similarity index 92% rename from src/swift/flwr/Sources/Flower/CoreML/MLFlwrClient.swift rename to examples/ios/FLiOS/CoreMLClient/MLFlwrClient.swift index 4686a06be9c..2254113a551 100644 --- a/src/swift/flwr/Sources/Flower/CoreML/MLFlwrClient.swift +++ b/examples/ios/FLiOS/CoreMLClient/MLFlwrClient.swift @@ -10,6 +10,7 @@ import NIOCore import NIOPosix import CoreML import os +import flwr /// Set of CoreML machine learning task options. public enum MLTask { @@ -55,6 +56,7 @@ public class MLFlwrClient: Client { private var compiledModelUrl: URL private var tempModelUrl: URL + private var modelUrl: URL private let log = Logger(subsystem: Bundle.main.bundleIdentifier ?? "flwr.Flower", category: String(describing: MLFlwrClient.self)) @@ -65,8 +67,9 @@ public class MLFlwrClient: Client { /// - layerWrappers: A MLLayerWrapper struct that contains layer information. /// - dataLoader: A MLDataLoader struct that contains train- and testdata batches. /// - compiledModelUrl: An URL specifying the location or path of the compiled model. - public init(layerWrappers: [MLLayerWrapper], dataLoader: MLDataLoader, compiledModelUrl: URL) { - self.parameters = MLParameter(layerWrappers: layerWrappers) + public init(layerWrappers: [MLLayerWrapper], dataLoader: MLDataLoader, compiledModelUrl: URL, modelUrl: URL) { + self.modelUrl = modelUrl + self.parameters = MLParameter(layerWrappers: layerWrappers, modelUrl: modelUrl, compiledModelUrl: compiledModelUrl) self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) self.dataLoader = dataLoader self.compiledModelUrl = compiledModelUrl @@ -85,6 +88,7 @@ public class MLFlwrClient: Client { /// /// - Returns: Parameters from the local model public func getParameters() -> GetParametersRes { + parameters.initializeParameters() let parameters = parameters.weightsToParameters() let status = Status(code: .ok, message: String()) @@ -128,7 +132,7 @@ public class MLFlwrClient: Client { } let progressHandler: (MLUpdateContext) -> Void = { contextProgress in - let loss = String(format: "%.4f", contextProgress.metrics[.lossValue] as! Double) + let loss = String(format: "%.20f", contextProgress.metrics[.lossValue] as! Float) switch task { case .train: self.log.info("Epoch \(contextProgress.metrics[.epochIndex] as! Int + 1) finished with loss \(loss)") @@ -143,8 +147,8 @@ public class MLFlwrClient: Client { self.saveModel(finalContext) } - let loss = finalContext.metrics[.lossValue] as! Double - let result = MLResult(loss: loss, numSamples: dataset.count, accuracy: (1.0 - loss) * 100) + let loss = String(format: "%.20f", finalContext.metrics[.lossValue] as! Double) + let result = MLResult(loss: Double(loss)!, numSamples: dataset.count, accuracy: (1.0 - Double(loss)!) * 100) promise?.succeed(result) } diff --git a/src/swift/flwr/Sources/Flower/CoreML/MLFlwrClientModel.swift b/examples/ios/FLiOS/CoreMLClient/MLFlwrClientModel.swift similarity index 55% rename from src/swift/flwr/Sources/Flower/CoreML/MLFlwrClientModel.swift rename to examples/ios/FLiOS/CoreMLClient/MLFlwrClientModel.swift index e9b76672f8e..f13ce3a25a4 100644 --- a/src/swift/flwr/Sources/Flower/CoreML/MLFlwrClientModel.swift +++ b/examples/ios/FLiOS/CoreMLClient/MLFlwrClientModel.swift @@ -8,6 +8,11 @@ import Foundation import CoreML import os +import flwr + +typealias Model = CoreML_Specification_Model +typealias NeuralNetwork = CoreML_Specification_NeuralNetwork +typealias NeuralNetworkLayer = CoreML_Specification_NeuralNetworkLayer /// Container for train and test dataset. public struct MLDataLoader { @@ -47,6 +52,10 @@ public class MLParameter { private var parameterConverter = ParameterConverter.shared var layerWrappers: [MLLayerWrapper] + var model: Model? + let modelUrl: URL + let compiledModelUrl: URL + private let log = Logger(subsystem: Bundle.main.bundleIdentifier ?? "flwr.Flower", category: String(describing: MLParameter.self)) @@ -54,8 +63,11 @@ public class MLParameter { /// /// - Parameters: /// - layerWrappers: Information about the layer provided with primitive data types. - public init(layerWrappers: [MLLayerWrapper]) { + public init(layerWrappers: [MLLayerWrapper], modelUrl: URL, compiledModelUrl: URL) { self.layerWrappers = layerWrappers + self.modelUrl = modelUrl + self.compiledModelUrl = compiledModelUrl + self.model = try? Model(serializedData: try Data(contentsOf: modelUrl)) } /// Converts the Parameters structure to MLModelConfiguration to interface with CoreML. @@ -81,19 +93,28 @@ public class MLParameter { layerWrappers[index].weights = weightsArray if layerWrappers[index].isUpdatable { - if let weightsMultiArray = parameterConverter.dataToMultiArray(data: data) { - let weightsShape = weightsMultiArray.shape.map { Int16(truncating: $0) } - guard weightsShape == layerWrappers[index].shape else { - log.info("shape not the same") + guard model?.neuralNetwork.layers != nil else { + continue + } + for (indexB, neuralNetworkLayer) in model!.neuralNetwork.layers.enumerated() { + guard layerWrappers[index].name == neuralNetworkLayer.name else { + continue + } + switch neuralNetworkLayer.layer! { + case .convolution: + model!.neuralNetwork.layers[indexB].convolution.weights.floatValue = layerWrappers[index].weights + case .innerProduct: + model!.neuralNetwork.layers[indexB].innerProduct.weights.floatValue = layerWrappers[index].weights + default: + log.info("unexpected layer \(neuralNetworkLayer.name)") continue } - let paramKey = MLParameterKey.weights.scoped(to: layerWrappers[index].name) - config.parameters?[paramKey] = weightsMultiArray } } } } + exportModel() return config } @@ -108,6 +129,47 @@ public class MLParameter { return Parameters(tensors: dataArray, tensorType: "ndarray") } + private func exportModel() { + let modelFileName = modelUrl.deletingPathExtension().lastPathComponent + let fileManager = FileManager.default + let tempModelUrl = appDirectory.appendingPathComponent("temp\(modelFileName).mlmodel") + try? model?.serializedData().write(to: tempModelUrl) + if let compiledTempModelUrl = try? MLModel.compileModel(at: tempModelUrl) { + _ = try? fileManager.replaceItemAt(compiledModelUrl, withItemAt: compiledTempModelUrl) + } + } + + func initializeParameters() { + guard ((model?.neuralNetwork.layers) != nil) else { + return + } + for (indexA, neuralNetworkLayer) in model!.neuralNetwork.layers.enumerated() { + for (indexB, layer) in layerWrappers.enumerated() { + if layer.name != neuralNetworkLayer.name { continue } + switch neuralNetworkLayer.layer! { + case .convolution: + let convolution = neuralNetworkLayer.convolution + //shape definition = [outputChannels, kernelChannels, kernelHeight, kernelWidth] + let upperLower = Float(6.0 / Float(Int16(convolution.outputChannels) + Int16(convolution.kernelChannels) + Int16(convolution.kernelSize[0]) + Int16(convolution.kernelSize[1]))).squareRoot() + let initialise = (0..<(neuralNetworkLayer.convolution.weights.floatValue.count)).map { _ in Float.random(in: -upperLower...upperLower) } + model?.neuralNetwork.layers[indexA].convolution.weights.floatValue = initialise + layerWrappers[indexB].weights = initialise + case .innerProduct: + let innerProduct = neuralNetworkLayer.innerProduct + //shape definition = [C_out, C_in]. + let upperLower = Float(6.0 / Float(Int16(innerProduct.outputChannels) + Int16(innerProduct.inputChannels))).squareRoot() + let initialise = (0..<(neuralNetworkLayer.innerProduct.weights.floatValue.count)).map { _ in Float.random(in: -upperLower...upperLower) } + model?.neuralNetwork.layers[indexA].innerProduct.weights.floatValue = initialise + layerWrappers[indexB].weights = initialise + default: + log.info("unexpected layer \(neuralNetworkLayer.name)") + continue + } + } + } + exportModel() + } + /// Updates the layers given the CoreML update context /// /// - Parameters: diff --git a/examples/ios/FLiOS/FLiOSModel.swift b/examples/ios/FLiOS/FLiOSModel.swift index f5004a6b969..c12a968e976 100644 --- a/examples/ios/FLiOS/FLiOSModel.swift +++ b/examples/ios/FLiOS/FLiOSModel.swift @@ -131,7 +131,8 @@ public class FLiOSModel: ObservableObject { let layerWrappers = modelInspect.getLayerWrappers() self.mlFlwrClient = MLFlwrClient(layerWrappers: layerWrappers, dataLoader: dataLoader, - compiledModelUrl: compiledModelUrl) + compiledModelUrl: compiledModelUrl, + modelUrl: url) case .local: self.localClient = LocalClient(dataLoader: dataLoader, compiledModelUrl: compiledModelUrl) } @@ -203,7 +204,7 @@ class LocalClient { func runMLTask(statusHandler: @escaping (Constants.TaskStatus) -> Void, numEpochs: Int, - task: flwr.MLTask + task: MLTask ) { let dataset: MLBatchProvider let configuration = MLModelConfiguration()